1use std::{
2 fmt, io,
3 net::{IpAddr, SocketAddr},
4 ops::Range,
5 str::FromStr,
6 time::Duration,
7};
8
9use bytes::Bytes;
10use rand::{seq::SliceRandom, thread_rng};
11use signal_protocol as protocol;
12use solana_pubkey::Pubkey;
13use solana_signature::Signature;
14use solana_transaction::versioned::VersionedTransaction;
15use thiserror::Error;
16use tokio::{
17 io::{AsyncReadExt, AsyncWriteExt},
18 net::{lookup_host, TcpStream},
19 time,
20};
21
22pub const DEFAULT_SIGNAL_PORT: u16 = 9000;
23pub const DEFAULT_SIGNAL_REGION: SignalRegion = SignalRegion::Ewr;
24
25const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
26const DEFAULT_MAX_FRAME_BODY_LEN: usize = 1024 * 1024;
27const DEFAULT_RECONNECT_BACKOFF: Duration = Duration::from_millis(100);
28const SIGNAL_DNS_SUFFIX: &str = "signals.helius-rpc.com";
29
30pub use protocol::{
31 ProgramFilter, ProgramFilterEntry, ProgramFilterSet, ProgramIdBytes,
32 DEFAULT_PROGRAM_FILTER_SET, PROGRAM_FILTERS,
33};
34
/// A signal-service region.
///
/// Each region maps to its own DNS hostname (see `SignalRegion::hostname`) and
/// parses from / displays as a short lowercase code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SignalRegion {
    /// The `"ewr"` region (the default).
    Ewr,
    /// The `"ams"` region.
    Ams,
}
40
41impl SignalRegion {
42 pub fn as_str(self) -> &'static str {
43 match self {
44 Self::Ewr => "ewr",
45 Self::Ams => "ams",
46 }
47 }
48
49 pub fn hostname(self) -> String {
50 format!("{}.{}", self.as_str(), SIGNAL_DNS_SUFFIX)
51 }
52}
53
54impl Default for SignalRegion {
55 fn default() -> Self {
56 DEFAULT_SIGNAL_REGION
57 }
58}
59
60impl fmt::Display for SignalRegion {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.write_str(self.as_str())
63 }
64}
65
66impl FromStr for SignalRegion {
67 type Err = ParseSignalRegionError;
68
69 fn from_str(value: &str) -> Result<Self, Self::Err> {
70 match value.to_ascii_lowercase().as_str() {
71 "ewr" => Ok(Self::Ewr),
72 "ams" => Ok(Self::Ams),
73 _ => Err(ParseSignalRegionError {
74 input: value.to_string(),
75 }),
76 }
77 }
78}
79
80#[derive(Debug, Clone, Error, PartialEq, Eq)]
81#[error("invalid signal region {input:?}; supported regions are ewr and ams")]
82pub struct ParseSignalRegionError {
83 input: String,
84}
85
86#[derive(Debug, Clone, PartialEq, Eq)]
87pub struct SignalEndpoint {
88 kind: SignalEndpointKind,
89}
90
91#[derive(Debug, Clone, PartialEq, Eq)]
92enum SignalEndpointKind {
93 Region(SignalRegion),
94 Addr(SocketAddr),
95 #[cfg(test)]
96 StaticAddrs(Vec<SocketAddr>),
97}
98
99impl SignalEndpoint {
100 pub fn region(region: SignalRegion) -> Self {
101 Self {
102 kind: SignalEndpointKind::Region(region),
103 }
104 }
105
106 pub fn addr(addr: SocketAddr) -> Self {
107 Self {
108 kind: SignalEndpointKind::Addr(addr),
109 }
110 }
111
112 pub fn ip(ip: IpAddr) -> Self {
113 Self::addr(SocketAddr::new(ip, DEFAULT_SIGNAL_PORT))
114 }
115
116 async fn resolve_addrs(&self) -> Result<Vec<SocketAddr>, SignalClientError> {
117 match &self.kind {
118 SignalEndpointKind::Region(region) => {
119 let hostname = region.hostname();
120 let mut addrs = Vec::new();
121 for addr in lookup_host((hostname.as_str(), DEFAULT_SIGNAL_PORT)).await? {
122 if !addrs.contains(&addr) {
123 addrs.push(addr);
124 }
125 }
126 if addrs.is_empty() {
127 return Err(SignalClientError::NoResolvedAddresses(self.to_string()));
128 }
129 shuffle_addrs(&mut addrs);
130 Ok(addrs)
131 }
132 SignalEndpointKind::Addr(addr) => Ok(vec![*addr]),
133 #[cfg(test)]
134 SignalEndpointKind::StaticAddrs(addrs) => {
135 if addrs.is_empty() {
136 return Err(SignalClientError::NoResolvedAddresses(self.to_string()));
137 }
138 Ok(addrs.clone())
139 }
140 }
141 }
142}
143
144impl Default for SignalEndpoint {
145 fn default() -> Self {
146 Self::region(DEFAULT_SIGNAL_REGION)
147 }
148}
149
150impl From<SignalRegion> for SignalEndpoint {
151 fn from(region: SignalRegion) -> Self {
152 Self::region(region)
153 }
154}
155
156impl From<SocketAddr> for SignalEndpoint {
157 fn from(addr: SocketAddr) -> Self {
158 Self::addr(addr)
159 }
160}
161
162impl From<IpAddr> for SignalEndpoint {
163 fn from(ip: IpAddr) -> Self {
164 Self::ip(ip)
165 }
166}
167
168impl fmt::Display for SignalEndpoint {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 match &self.kind {
171 SignalEndpointKind::Region(region) => f.write_str(region.as_str()),
172 SignalEndpointKind::Addr(addr) => write!(f, "{addr}"),
173 #[cfg(test)]
174 SignalEndpointKind::StaticAddrs(addrs) => write!(f, "{addrs:?}"),
175 }
176 }
177}
178
179impl FromStr for SignalEndpoint {
180 type Err = ParseSignalEndpointError;
181
182 fn from_str(value: &str) -> Result<Self, Self::Err> {
183 if let Ok(region) = value.parse::<SignalRegion>() {
184 return Ok(Self::region(region));
185 }
186 if let Ok(addr) = value.parse::<SocketAddr>() {
187 return Ok(Self::addr(addr));
188 }
189 if let Ok(ip) = value.parse::<IpAddr>() {
190 return Ok(Self::ip(ip));
191 }
192 Err(ParseSignalEndpointError {
193 input: value.to_string(),
194 })
195 }
196}
197
198#[derive(Debug, Clone, Error, PartialEq, Eq)]
199#[error("invalid signal endpoint {input:?}; use ewr, ams, IP, or IP:port")]
200pub struct ParseSignalEndpointError {
201 input: String,
202}
203
204fn shuffle_addrs(addrs: &mut [SocketAddr]) {
205 addrs.shuffle(&mut thread_rng());
206}
207
208#[derive(Debug, Clone)]
209pub struct SignalClientConfig {
210 pub endpoint: SignalEndpoint,
211 pub program_id: [u8; 32],
212 pub connect_timeout: Duration,
213 pub tcp_nodelay: bool,
214 pub max_frame_body_len: usize,
215}
216
217impl SignalClientConfig {
218 pub fn new(endpoint: impl Into<SignalEndpoint>, program_id: Pubkey) -> Self {
219 Self::new_from_program_id_bytes(endpoint, program_id.to_bytes())
220 }
221
222 pub fn new_for_program_filter(
223 endpoint: impl Into<SignalEndpoint>,
224 program_filter: ProgramFilter,
225 ) -> Self {
226 Self::new_from_program_id_bytes(endpoint, program_filter.bytes())
227 }
228
229 pub fn new_from_program_id_bytes(
230 endpoint: impl Into<SignalEndpoint>,
231 program_id: [u8; 32],
232 ) -> Self {
233 Self {
234 endpoint: endpoint.into(),
235 program_id,
236 connect_timeout: DEFAULT_CONNECT_TIMEOUT,
237 tcp_nodelay: true,
238 max_frame_body_len: DEFAULT_MAX_FRAME_BODY_LEN,
239 }
240 }
241}
242
243#[derive(Debug, Clone)]
244pub struct SignalTransaction {
245 pub slot: u64,
246 pub created_at_unix_nanos: i64,
247 pub signature: Signature,
248 pub transaction: VersionedTransaction,
249}
250
251#[derive(Debug, Clone)]
252pub struct RawSignalTransaction {
253 pub slot: u64,
254 pub created_at_unix_nanos: i64,
255 pub signature: Signature,
256 pub serialized_transaction: Bytes,
257}
258
259impl RawSignalTransaction {
260 pub fn into_transaction(self) -> Result<SignalTransaction, SignalClientError> {
261 let transaction = bincode::deserialize(&self.serialized_transaction)
262 .map_err(SignalClientError::DeserializeTransaction)?;
263 Ok(SignalTransaction {
264 slot: self.slot,
265 created_at_unix_nanos: self.created_at_unix_nanos,
266 signature: self.signature,
267 transaction,
268 })
269 }
270}
271
272#[derive(Debug, Clone, Copy)]
273pub struct BorrowedRawSignalTransaction<'a> {
274 pub slot: u64,
275 pub created_at_unix_nanos: i64,
276 pub signature: Signature,
277 pub serialized_transaction: &'a [u8],
278}
279
280impl BorrowedRawSignalTransaction<'_> {
281 pub fn deserialize_transaction(&self) -> Result<VersionedTransaction, SignalClientError> {
282 bincode::deserialize(self.serialized_transaction)
283 .map_err(SignalClientError::DeserializeTransaction)
284 }
285
286 pub fn to_owned(self) -> RawSignalTransaction {
287 RawSignalTransaction {
288 slot: self.slot,
289 created_at_unix_nanos: self.created_at_unix_nanos,
290 signature: self.signature,
291 serialized_transaction: Bytes::copy_from_slice(self.serialized_transaction),
292 }
293 }
294}
295
296#[derive(Debug, Error)]
297pub enum SignalClientError {
298 #[error("signal client builder requires a program filter or program id")]
299 MissingProgramFilter,
300 #[error("connect timeout after {0:?}")]
301 ConnectTimeout(Duration),
302 #[error("no addresses resolved for signal endpoint {0}")]
303 NoResolvedAddresses(String),
304 #[error("signal stream closed by peer")]
305 StreamClosed,
306 #[error("io error: {0}")]
307 Io(#[from] io::Error),
308 #[error("invalid frame body length {len}; max configured length is {max}")]
309 FrameTooLarge { len: usize, max: usize },
310 #[error("invalid zero-length frame body")]
311 EmptyFrame,
312 #[error("unexpected signal opcode {0:#04x}")]
313 UnexpectedOpcode(u8),
314 #[error("failed to deserialize VersionedTransaction: {0}")]
315 DeserializeTransaction(bincode::Error),
316 #[error("protocol error: {0}")]
317 Protocol(#[from] protocol::WireError),
318}
319
320pub struct SignalStream {
321 stream: TcpStream,
322 max_frame_body_len: usize,
323 header: [u8; protocol::FRAME_HEADER_LEN],
324 read_buf: Vec<u8>,
325 read_pos: usize,
326 payload_range: Range<usize>,
327}
328
329impl SignalStream {
330 pub async fn connect(config: SignalClientConfig) -> Result<Self, SignalClientError> {
331 let addrs = config.endpoint.resolve_addrs().await?;
332 let mut last_error = None;
333 for addr in addrs {
334 match Self::connect_addr(&config, addr).await {
335 Ok(stream) => return Ok(stream),
336 Err(err) => last_error = Some(err),
337 }
338 }
339
340 Err(last_error
341 .unwrap_or_else(|| SignalClientError::NoResolvedAddresses(config.endpoint.to_string())))
342 }
343
344 async fn connect_addr(
345 config: &SignalClientConfig,
346 addr: SocketAddr,
347 ) -> Result<Self, SignalClientError> {
348 let stream = match time::timeout(config.connect_timeout, TcpStream::connect(addr)).await {
349 Ok(result) => result?,
350 Err(_) => return Err(SignalClientError::ConnectTimeout(config.connect_timeout)),
351 };
352 stream.set_nodelay(config.tcp_nodelay)?;
353
354 let mut client = Self {
355 stream,
356 max_frame_body_len: config.max_frame_body_len,
357 header: [0; protocol::FRAME_HEADER_LEN],
358 read_buf: Vec::with_capacity(64 * 1024),
359 read_pos: 0,
360 payload_range: 0..0,
361 };
362 client.write_start_stream(config.program_id).await?;
363 Ok(client)
364 }
365
366 pub async fn next_transaction(&mut self) -> Result<SignalTransaction, SignalClientError> {
367 self.read_next_tx_payload().await?;
368
369 let payload = self.current_payload();
370 let metadata = protocol::parse_tx_payload_metadata(payload)?;
371 let transaction = bincode::deserialize(&payload[protocol::TX_METADATA_LEN..])
372 .map_err(SignalClientError::DeserializeTransaction)?;
373
374 Ok(SignalTransaction {
375 slot: metadata.slot,
376 created_at_unix_nanos: metadata.created_at_ns,
377 signature: Signature::from(metadata.signature),
378 transaction,
379 })
380 }
381
382 pub async fn next_raw(&mut self) -> Result<RawSignalTransaction, SignalClientError> {
383 Ok(self.next_raw_borrowed().await?.to_owned())
384 }
385
386 pub async fn next_raw_borrowed(
387 &mut self,
388 ) -> Result<BorrowedRawSignalTransaction<'_>, SignalClientError> {
389 self.read_next_tx_payload().await?;
390 self.borrow_raw_payload()
391 }
392
393 fn borrow_raw_payload(&self) -> Result<BorrowedRawSignalTransaction<'_>, SignalClientError> {
394 let payload = self.current_payload();
395 let metadata = protocol::parse_tx_payload_metadata(payload)?;
396 Ok(BorrowedRawSignalTransaction {
397 slot: metadata.slot,
398 created_at_unix_nanos: metadata.created_at_ns,
399 signature: Signature::from(metadata.signature),
400 serialized_transaction: &payload[protocol::TX_METADATA_LEN..],
401 })
402 }
403
404 pub fn into_inner(self) -> TcpStream {
405 self.stream
406 }
407
408 async fn write_start_stream(&mut self, program_id: [u8; 32]) -> Result<(), SignalClientError> {
409 let frame = protocol::encode_start_stream_frame(&program_id);
410 self.stream.write_all(&frame).await?;
411 Ok(())
412 }
413
414 async fn read_next_tx_payload(&mut self) -> Result<(), SignalClientError> {
415 self.read_next_frame().await?;
416
417 let opcode = self.header[4];
418 if opcode != protocol::MSG_TX {
419 return Err(SignalClientError::UnexpectedOpcode(opcode));
420 }
421 let payload = self.current_payload();
422 if payload.len() < protocol::TX_METADATA_LEN {
423 protocol::parse_tx_payload_metadata(payload)?;
424 }
425
426 Ok(())
427 }
428
429 async fn read_next_frame(&mut self) -> Result<(), SignalClientError> {
430 self.payload_range = 0..0;
431 self.compact_read_buffer();
432 self.ensure_buffered(protocol::FRAME_HEADER_LEN).await?;
433
434 let header_start = self.read_pos;
435 let header_end = header_start + protocol::FRAME_HEADER_LEN;
436 self.header
437 .copy_from_slice(&self.read_buf[header_start..header_end]);
438
439 let body_len = u32::from_le_bytes(self.header[..4].try_into().unwrap()) as usize;
440 if body_len == 0 {
441 return Err(SignalClientError::EmptyFrame);
442 }
443 if body_len > self.max_frame_body_len {
444 return Err(SignalClientError::FrameTooLarge {
445 len: body_len,
446 max: self.max_frame_body_len,
447 });
448 }
449
450 let payload_len = body_len - 1;
451 let frame_len = protocol::FRAME_HEADER_LEN + payload_len;
452 self.ensure_buffered(frame_len).await?;
453
454 let payload_start = header_end;
455 let payload_end = payload_start + payload_len;
456 self.payload_range = payload_start..payload_end;
457 self.read_pos = payload_end;
458 Ok(())
459 }
460
461 async fn ensure_buffered(&mut self, needed: usize) -> Result<(), SignalClientError> {
462 while self.read_buf.len().saturating_sub(self.read_pos) < needed {
463 let read = self.stream.read_buf(&mut self.read_buf).await?;
464 if read == 0 {
465 return Err(SignalClientError::StreamClosed);
466 }
467 }
468 Ok(())
469 }
470
471 fn compact_read_buffer(&mut self) {
472 if self.read_pos == 0 {
473 return;
474 }
475 if self.read_pos >= self.read_buf.len() {
476 self.read_buf.clear();
477 self.read_pos = 0;
478 return;
479 }
480
481 let remaining = self.read_buf.len() - self.read_pos;
482 self.read_buf.copy_within(self.read_pos.., 0);
483 self.read_buf.truncate(remaining);
484 self.read_pos = 0;
485 }
486
487 fn current_payload(&self) -> &[u8] {
488 &self.read_buf[self.payload_range.clone()]
489 }
490}
491
492#[derive(Debug, Clone)]
493pub struct SignalClientBuilder {
494 endpoint: SignalEndpoint,
495 program_id: Option<[u8; 32]>,
496 connect_timeout: Duration,
497 tcp_nodelay: bool,
498 max_frame_body_len: usize,
499 reconnect: bool,
500 reconnect_backoff: Duration,
501}
502
503impl Default for SignalClientBuilder {
504 fn default() -> Self {
505 Self {
506 endpoint: SignalEndpoint::default(),
507 program_id: None,
508 connect_timeout: DEFAULT_CONNECT_TIMEOUT,
509 tcp_nodelay: true,
510 max_frame_body_len: DEFAULT_MAX_FRAME_BODY_LEN,
511 reconnect: true,
512 reconnect_backoff: DEFAULT_RECONNECT_BACKOFF,
513 }
514 }
515}
516
517impl SignalClientBuilder {
518 pub fn endpoint(mut self, endpoint: impl Into<SignalEndpoint>) -> Self {
519 self.endpoint = endpoint.into();
520 self
521 }
522
523 pub fn region(self, region: SignalRegion) -> Self {
524 self.endpoint(region)
525 }
526
527 pub fn addr(self, addr: SocketAddr) -> Self {
528 self.endpoint(addr)
529 }
530
531 pub fn ip(self, ip: IpAddr) -> Self {
532 self.endpoint(ip)
533 }
534
535 pub fn program_id(mut self, program_id: Pubkey) -> Self {
536 self.program_id = Some(program_id.to_bytes());
537 self
538 }
539
540 pub fn program_filter(mut self, program_filter: ProgramFilter) -> Self {
541 self.program_id = Some(program_filter.bytes());
542 self
543 }
544
545 pub fn program_id_bytes(mut self, program_id: [u8; 32]) -> Self {
546 self.program_id = Some(program_id);
547 self
548 }
549
550 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
551 self.connect_timeout = timeout;
552 self
553 }
554
555 pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
556 self.tcp_nodelay = enabled;
557 self
558 }
559
560 pub fn max_frame_body_len(mut self, len: usize) -> Self {
561 self.max_frame_body_len = len;
562 self
563 }
564
565 pub fn reconnect(mut self, enabled: bool) -> Self {
566 self.reconnect = enabled;
567 self
568 }
569
570 pub fn reconnect_backoff(mut self, backoff: Duration) -> Self {
571 self.reconnect_backoff = backoff;
572 self
573 }
574
575 pub async fn connect(self) -> Result<SignalClient, SignalClientError> {
576 SignalClient::connect(self).await
577 }
578
579 fn stream_config(&self) -> Result<SignalClientConfig, SignalClientError> {
580 let program_id = self
581 .program_id
582 .ok_or(SignalClientError::MissingProgramFilter)?;
583 Ok(SignalClientConfig {
584 endpoint: self.endpoint.clone(),
585 program_id,
586 connect_timeout: self.connect_timeout,
587 tcp_nodelay: self.tcp_nodelay,
588 max_frame_body_len: self.max_frame_body_len,
589 })
590 }
591}
592
593pub struct SignalClient {
594 stream_config: SignalClientConfig,
595 reconnect: bool,
596 reconnect_backoff: Duration,
597 stream: SignalStream,
598}
599
600impl SignalClient {
601 pub fn builder() -> SignalClientBuilder {
602 SignalClientBuilder::default()
603 }
604
605 pub async fn connect(builder: SignalClientBuilder) -> Result<Self, SignalClientError> {
606 let stream_config = builder.stream_config()?;
607 let stream = SignalStream::connect(stream_config.clone()).await?;
608 Ok(Self {
609 stream_config,
610 reconnect: builder.reconnect,
611 reconnect_backoff: builder.reconnect_backoff,
612 stream,
613 })
614 }
615
616 pub async fn next_raw(&mut self) -> Result<RawSignalTransaction, SignalClientError> {
617 loop {
618 match self.stream.next_raw().await {
619 Ok(tx) => return Ok(tx),
620 Err(err) if self.should_reconnect(&err) => self.reconnect().await?,
621 Err(err) => return Err(err),
622 }
623 }
624 }
625
626 pub async fn next_transaction(&mut self) -> Result<SignalTransaction, SignalClientError> {
627 self.next_raw().await?.into_transaction()
628 }
629
630 async fn reconnect(&mut self) -> Result<(), SignalClientError> {
631 loop {
632 time::sleep(self.reconnect_backoff).await;
633 match SignalStream::connect(self.stream_config.clone()).await {
634 Ok(stream) => {
635 self.stream = stream;
636 return Ok(());
637 }
638 Err(err) if self.should_reconnect(&err) => continue,
639 Err(err) => return Err(err),
640 }
641 }
642 }
643
644 fn should_reconnect(&self, err: &SignalClientError) -> bool {
645 self.reconnect
646 && matches!(
647 err,
648 SignalClientError::StreamClosed
649 | SignalClientError::Io(_)
650 | SignalClientError::ConnectTimeout(_)
651 | SignalClientError::NoResolvedAddresses(_)
652 )
653 }
654
655 pub fn into_stream(self) -> SignalStream {
656 self.stream
657 }
658}
659
#[cfg(test)]
mod tests {
    // Tests run the client against in-process TCP listeners that speak the frame format
    // produced by `encode_tx_frame` below.
    use super::*;
    use solana_hash::Hash;
    use solana_instruction::Instruction;
    use solana_message::{Message, VersionedMessage};
    use tokio::net::TcpListener;

    #[test]
    fn parses_region_and_endpoint_values() {
        assert_eq!("EWR".parse::<SignalRegion>().unwrap(), SignalRegion::Ewr);
        assert_eq!("ams".parse::<SignalRegion>().unwrap(), SignalRegion::Ams);
        assert_eq!(SignalRegion::Ewr.hostname(), "ewr.signals.helius-rpc.com");

        assert_eq!(
            "ewr".parse::<SignalEndpoint>().unwrap(),
            SignalEndpoint::region(SignalRegion::Ewr)
        );
        assert_eq!(
            "127.0.0.1".parse::<SignalEndpoint>().unwrap(),
            SignalEndpoint::ip("127.0.0.1".parse::<IpAddr>().unwrap())
        );
        assert_eq!(
            "127.0.0.1:9001".parse::<SignalEndpoint>().unwrap(),
            SignalEndpoint::addr("127.0.0.1:9001".parse::<SocketAddr>().unwrap())
        );
    }

    #[tokio::test]
    async fn builder_requires_program_filter() {
        let result = SignalClient::builder()
            .addr("127.0.0.1:9000".parse::<SocketAddr>().unwrap())
            .connect()
            .await;

        assert!(matches!(
            result,
            Err(SignalClientError::MissingProgramFilter)
        ));
    }

    #[tokio::test]
    async fn connects_and_decodes_transaction() {
        let program_id = Pubkey::from([2u8; 32]);
        let signature = Signature::from([7u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            assert_eq!(
                u32::from_le_bytes(start[..4].try_into().unwrap()),
                protocol::START_STREAM_FRAME_BODY_LEN as u32
            );
            assert_eq!(start[4], protocol::MSG_START_STREAM);
            assert_eq!(&start[5..], &program_id.to_bytes());

            let frame = encode_tx_frame(42, 99, signature, &serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let received = stream.next_transaction().await.unwrap();

        assert_eq!(received.slot, 42);
        assert_eq!(received.created_at_unix_nanos, 99);
        assert_eq!(received.signature, signature);
        assert_eq!(received.transaction.signatures[0], signature);
        server.await.unwrap();
    }

    #[tokio::test]
    async fn connect_falls_back_to_next_available_addr() {
        let program_id = Pubkey::from([4u8; 32]);
        let signature = Signature::from([10u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();

        // Bind then drop so this port refuses connections.
        let bad_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bad_addr = bad_listener.local_addr().unwrap();
        drop(bad_listener);

        let good_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let good_addr = good_listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut socket, _) = good_listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            let frame = encode_tx_frame(46, 103, signature, &serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let endpoint = SignalEndpoint {
            kind: SignalEndpointKind::StaticAddrs(vec![bad_addr, good_addr]),
        };
        let mut config = SignalClientConfig::new(endpoint, program_id);
        config.connect_timeout = Duration::from_secs(1);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let received = stream.next_raw().await.unwrap();

        assert_eq!(received.slot, 46);
        assert_eq!(received.created_at_unix_nanos, 103);
        assert_eq!(received.signature, signature);
    }

    #[tokio::test]
    async fn raw_mode_preserves_serialized_transaction_bytes() {
        let program_id = Pubkey::from([3u8; 32]);
        let signature = Signature::from([8u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();
        let server_serialized = serialized.clone();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            let frame = encode_tx_frame(43, 100, signature, &server_serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let received = stream.next_raw().await.unwrap();

        assert_eq!(received.slot, 43);
        assert_eq!(received.created_at_unix_nanos, 100);
        assert_eq!(received.signature, signature);
        assert_eq!(&received.serialized_transaction[..], serialized.as_slice());
    }

    #[tokio::test]
    async fn borrowed_raw_mode_avoids_copying_serialized_transaction() {
        let program_id = Pubkey::from([6u8; 32]);
        let signature = Signature::from([9u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();
        let server_serialized = serialized.clone();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            let frame = encode_tx_frame(44, 101, signature, &server_serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let received = stream.next_raw_borrowed().await.unwrap();

        assert_eq!(received.slot, 44);
        assert_eq!(received.created_at_unix_nanos, 101);
        assert_eq!(received.signature, signature);
        assert_eq!(received.serialized_transaction, serialized.as_slice());
    }

    #[tokio::test]
    async fn decodes_multiple_frames_from_single_write() {
        let program_id = Pubkey::from([12u8; 32]);
        let first_signature = Signature::from([13u8; 64]);
        let second_signature = Signature::from([14u8; 64]);
        let first = bincode::serialize(&test_transaction(program_id, first_signature)).unwrap();
        let second = bincode::serialize(&test_transaction(program_id, second_signature)).unwrap();
        let server_second = second.clone();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            // Both frames in one write, so the client will usually buffer the second
            // frame while decoding the first (exercises the already-buffered path).
            let mut frames = encode_tx_frame(47, 104, first_signature, &first);
            frames.extend_from_slice(&encode_tx_frame(48, 105, second_signature, &server_second));
            socket.write_all(&frames).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();

        let first_received = stream.next_transaction().await.unwrap();
        assert_eq!(first_received.slot, 47);
        assert_eq!(first_received.created_at_unix_nanos, 104);
        assert_eq!(first_received.signature, first_signature);

        let second_received = stream.next_raw_borrowed().await.unwrap();
        assert_eq!(second_received.slot, 48);
        assert_eq!(second_received.created_at_unix_nanos, 105);
        assert_eq!(second_received.signature, second_signature);
        assert_eq!(second_received.serialized_transaction, second.as_slice());
    }

    #[tokio::test]
    async fn peer_close_returns_stream_closed_error() {
        let program_id = Pubkey::from([5u8; 32]);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let err = stream.next_transaction().await.unwrap_err();

        assert!(matches!(err, SignalClientError::StreamClosed));
    }

    #[tokio::test]
    async fn signal_client_next_transaction_decodes() {
        let program_id = Pubkey::from([15u8; 32]);
        let signature = Signature::from([16u8; 64]);
        let serialized = bincode::serialize(&test_transaction(program_id, signature)).unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            let frame = encode_tx_frame(49, 106, signature, &serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let mut client = SignalClient::builder()
            .addr(addr)
            .program_id(program_id)
            .connect()
            .await
            .unwrap();
        let received = time::timeout(Duration::from_secs(2), client.next_transaction())
            .await
            .unwrap()
            .unwrap();

        assert_eq!(received.slot, 49);
        assert_eq!(received.created_at_unix_nanos, 106);
        assert_eq!(received.signature, signature);
        assert_eq!(received.transaction.signatures[0], signature);
    }

    #[tokio::test]
    async fn signal_client_reconnects_after_peer_close() {
        let program_id = Pubkey::from([7u8; 32]);
        let signature = Signature::from([11u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            // First connection is dropped right after the handshake to force a reconnect.
            let (mut first, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            first.read_exact(&mut start).await.unwrap();
            drop(first);

            let (mut second, _) = listener.accept().await.unwrap();
            second.read_exact(&mut start).await.unwrap();
            let frame = encode_tx_frame(45, 102, signature, &serialized);
            second.write_all(&frame).await.unwrap();
        });

        let mut client = SignalClient::builder()
            .addr(addr)
            .program_id(program_id)
            .reconnect_backoff(Duration::from_millis(1))
            .connect()
            .await
            .unwrap();
        let received = time::timeout(Duration::from_secs(2), client.next_raw())
            .await
            .unwrap()
            .unwrap();

        assert_eq!(received.slot, 45);
        assert_eq!(received.created_at_unix_nanos, 102);
        assert_eq!(received.signature, signature);
    }

    /// Builds a minimal legacy transaction invoking `program_id`, signed with `signature`.
    fn test_transaction(program_id: Pubkey, signature: Signature) -> VersionedTransaction {
        let payer = Pubkey::from([1u8; 32]);
        let instruction = Instruction::new_with_bytes(program_id, &[1, 2, 3], vec![]);
        let message = Message::new_with_blockhash(&[instruction], Some(&payer), &Hash::default());
        VersionedTransaction {
            signatures: vec![signature],
            message: VersionedMessage::Legacy(message),
        }
    }

    /// Encodes a transaction frame: LE u32 body length, `MSG_TX` opcode, then
    /// slot, timestamp, and signature metadata, followed by the serialized transaction.
    fn encode_tx_frame(
        slot: u64,
        created_at_unix_nanos: i64,
        signature: Signature,
        serialized_transaction: &[u8],
    ) -> Vec<u8> {
        let payload_len = protocol::TX_METADATA_LEN + serialized_transaction.len();
        let body_len = 1 + payload_len;
        let mut frame = Vec::with_capacity(protocol::FRAME_HEADER_LEN + body_len);
        frame.extend_from_slice(&(body_len as u32).to_le_bytes());
        frame.push(protocol::MSG_TX);
        frame.extend_from_slice(&slot.to_le_bytes());
        frame.extend_from_slice(&created_at_unix_nanos.to_le_bytes());
        frame.extend_from_slice(signature.as_ref());
        frame.extend_from_slice(serialized_transaction);
        frame
    }
}