Skip to main content

signal_client/
lib.rs

1use std::{
2    fmt, io,
3    net::{IpAddr, SocketAddr},
4    ops::Range,
5    str::FromStr,
6    time::Duration,
7};
8
9use bytes::Bytes;
10use rand::{seq::SliceRandom, thread_rng};
11use signal_protocol as protocol;
12use solana_pubkey::Pubkey;
13use solana_signature::Signature;
14use solana_transaction::versioned::VersionedTransaction;
15use thiserror::Error;
16use tokio::{
17    io::{AsyncReadExt, AsyncWriteExt},
18    net::{lookup_host, TcpStream},
19    time,
20};
21
22pub const DEFAULT_SIGNAL_PORT: u16 = 9000;
23pub const DEFAULT_SIGNAL_REGION: SignalRegion = SignalRegion::Ewr;
24
25const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
26const DEFAULT_MAX_FRAME_BODY_LEN: usize = 1024 * 1024;
27const DEFAULT_RECONNECT_BACKOFF: Duration = Duration::from_millis(100);
28const SIGNAL_DNS_SUFFIX: &str = "signals.helius-rpc.com";
29
30pub use protocol::{
31    ProgramFilter, ProgramFilterEntry, ProgramFilterSet, ProgramIdBytes,
32    DEFAULT_PROGRAM_FILTER_SET, PROGRAM_FILTERS,
33};
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq)]
36pub enum SignalRegion {
37    Ewr,
38    Ams,
39}
40
41impl SignalRegion {
42    pub fn as_str(self) -> &'static str {
43        match self {
44            Self::Ewr => "ewr",
45            Self::Ams => "ams",
46        }
47    }
48
49    pub fn hostname(self) -> String {
50        format!("{}.{}", self.as_str(), SIGNAL_DNS_SUFFIX)
51    }
52}
53
54impl Default for SignalRegion {
55    fn default() -> Self {
56        DEFAULT_SIGNAL_REGION
57    }
58}
59
60impl fmt::Display for SignalRegion {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        f.write_str(self.as_str())
63    }
64}
65
66impl FromStr for SignalRegion {
67    type Err = ParseSignalRegionError;
68
69    fn from_str(value: &str) -> Result<Self, Self::Err> {
70        match value.to_ascii_lowercase().as_str() {
71            "ewr" => Ok(Self::Ewr),
72            "ams" => Ok(Self::Ams),
73            _ => Err(ParseSignalRegionError {
74                input: value.to_string(),
75            }),
76        }
77    }
78}
79
80#[derive(Debug, Clone, Error, PartialEq, Eq)]
81#[error("invalid signal region {input:?}; supported regions are ewr and ams")]
82pub struct ParseSignalRegionError {
83    input: String,
84}
85
86#[derive(Debug, Clone, PartialEq, Eq)]
87pub struct SignalEndpoint {
88    kind: SignalEndpointKind,
89}
90
91#[derive(Debug, Clone, PartialEq, Eq)]
92enum SignalEndpointKind {
93    Region(SignalRegion),
94    Addr(SocketAddr),
95    #[cfg(test)]
96    StaticAddrs(Vec<SocketAddr>),
97}
98
99impl SignalEndpoint {
100    pub fn region(region: SignalRegion) -> Self {
101        Self {
102            kind: SignalEndpointKind::Region(region),
103        }
104    }
105
106    pub fn addr(addr: SocketAddr) -> Self {
107        Self {
108            kind: SignalEndpointKind::Addr(addr),
109        }
110    }
111
112    pub fn ip(ip: IpAddr) -> Self {
113        Self::addr(SocketAddr::new(ip, DEFAULT_SIGNAL_PORT))
114    }
115
116    async fn resolve_addrs(&self) -> Result<Vec<SocketAddr>, SignalClientError> {
117        match &self.kind {
118            SignalEndpointKind::Region(region) => {
119                let hostname = region.hostname();
120                let mut addrs = Vec::new();
121                for addr in lookup_host((hostname.as_str(), DEFAULT_SIGNAL_PORT)).await? {
122                    if !addrs.contains(&addr) {
123                        addrs.push(addr);
124                    }
125                }
126                if addrs.is_empty() {
127                    return Err(SignalClientError::NoResolvedAddresses(self.to_string()));
128                }
129                shuffle_addrs(&mut addrs);
130                Ok(addrs)
131            }
132            SignalEndpointKind::Addr(addr) => Ok(vec![*addr]),
133            #[cfg(test)]
134            SignalEndpointKind::StaticAddrs(addrs) => {
135                if addrs.is_empty() {
136                    return Err(SignalClientError::NoResolvedAddresses(self.to_string()));
137                }
138                Ok(addrs.clone())
139            }
140        }
141    }
142}
143
144impl Default for SignalEndpoint {
145    fn default() -> Self {
146        Self::region(DEFAULT_SIGNAL_REGION)
147    }
148}
149
150impl From<SignalRegion> for SignalEndpoint {
151    fn from(region: SignalRegion) -> Self {
152        Self::region(region)
153    }
154}
155
156impl From<SocketAddr> for SignalEndpoint {
157    fn from(addr: SocketAddr) -> Self {
158        Self::addr(addr)
159    }
160}
161
162impl From<IpAddr> for SignalEndpoint {
163    fn from(ip: IpAddr) -> Self {
164        Self::ip(ip)
165    }
166}
167
168impl fmt::Display for SignalEndpoint {
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        match &self.kind {
171            SignalEndpointKind::Region(region) => f.write_str(region.as_str()),
172            SignalEndpointKind::Addr(addr) => write!(f, "{addr}"),
173            #[cfg(test)]
174            SignalEndpointKind::StaticAddrs(addrs) => write!(f, "{addrs:?}"),
175        }
176    }
177}
178
179impl FromStr for SignalEndpoint {
180    type Err = ParseSignalEndpointError;
181
182    fn from_str(value: &str) -> Result<Self, Self::Err> {
183        if let Ok(region) = value.parse::<SignalRegion>() {
184            return Ok(Self::region(region));
185        }
186        if let Ok(addr) = value.parse::<SocketAddr>() {
187            return Ok(Self::addr(addr));
188        }
189        if let Ok(ip) = value.parse::<IpAddr>() {
190            return Ok(Self::ip(ip));
191        }
192        Err(ParseSignalEndpointError {
193            input: value.to_string(),
194        })
195    }
196}
197
198#[derive(Debug, Clone, Error, PartialEq, Eq)]
199#[error("invalid signal endpoint {input:?}; use ewr, ams, IP, or IP:port")]
200pub struct ParseSignalEndpointError {
201    input: String,
202}
203
204fn shuffle_addrs(addrs: &mut [SocketAddr]) {
205    addrs.shuffle(&mut thread_rng());
206}
207
208#[derive(Debug, Clone)]
209pub struct SignalClientConfig {
210    pub endpoint: SignalEndpoint,
211    pub program_id: [u8; 32],
212    pub connect_timeout: Duration,
213    pub tcp_nodelay: bool,
214    pub max_frame_body_len: usize,
215}
216
217impl SignalClientConfig {
218    pub fn new(endpoint: impl Into<SignalEndpoint>, program_id: Pubkey) -> Self {
219        Self::new_from_program_id_bytes(endpoint, program_id.to_bytes())
220    }
221
222    pub fn new_for_program_filter(
223        endpoint: impl Into<SignalEndpoint>,
224        program_filter: ProgramFilter,
225    ) -> Self {
226        Self::new_from_program_id_bytes(endpoint, program_filter.bytes())
227    }
228
229    pub fn new_from_program_id_bytes(
230        endpoint: impl Into<SignalEndpoint>,
231        program_id: [u8; 32],
232    ) -> Self {
233        Self {
234            endpoint: endpoint.into(),
235            program_id,
236            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
237            tcp_nodelay: true,
238            max_frame_body_len: DEFAULT_MAX_FRAME_BODY_LEN,
239        }
240    }
241}
242
243#[derive(Debug, Clone)]
244pub struct SignalTransaction {
245    pub slot: u64,
246    pub created_at_unix_nanos: i64,
247    pub signature: Signature,
248    pub transaction: VersionedTransaction,
249}
250
251#[derive(Debug, Clone)]
252pub struct RawSignalTransaction {
253    pub slot: u64,
254    pub created_at_unix_nanos: i64,
255    pub signature: Signature,
256    pub serialized_transaction: Bytes,
257}
258
259impl RawSignalTransaction {
260    pub fn into_transaction(self) -> Result<SignalTransaction, SignalClientError> {
261        let transaction = bincode::deserialize(&self.serialized_transaction)
262            .map_err(SignalClientError::DeserializeTransaction)?;
263        Ok(SignalTransaction {
264            slot: self.slot,
265            created_at_unix_nanos: self.created_at_unix_nanos,
266            signature: self.signature,
267            transaction,
268        })
269    }
270}
271
272#[derive(Debug, Clone, Copy)]
273pub struct BorrowedRawSignalTransaction<'a> {
274    pub slot: u64,
275    pub created_at_unix_nanos: i64,
276    pub signature: Signature,
277    pub serialized_transaction: &'a [u8],
278}
279
280impl BorrowedRawSignalTransaction<'_> {
281    pub fn deserialize_transaction(&self) -> Result<VersionedTransaction, SignalClientError> {
282        bincode::deserialize(self.serialized_transaction)
283            .map_err(SignalClientError::DeserializeTransaction)
284    }
285
286    pub fn to_owned(self) -> RawSignalTransaction {
287        RawSignalTransaction {
288            slot: self.slot,
289            created_at_unix_nanos: self.created_at_unix_nanos,
290            signature: self.signature,
291            serialized_transaction: Bytes::copy_from_slice(self.serialized_transaction),
292        }
293    }
294}
295
296#[derive(Debug, Error)]
297pub enum SignalClientError {
298    #[error("signal client builder requires a program filter or program id")]
299    MissingProgramFilter,
300    #[error("connect timeout after {0:?}")]
301    ConnectTimeout(Duration),
302    #[error("no addresses resolved for signal endpoint {0}")]
303    NoResolvedAddresses(String),
304    #[error("signal stream closed by peer")]
305    StreamClosed,
306    #[error("io error: {0}")]
307    Io(#[from] io::Error),
308    #[error("invalid frame body length {len}; max configured length is {max}")]
309    FrameTooLarge { len: usize, max: usize },
310    #[error("invalid zero-length frame body")]
311    EmptyFrame,
312    #[error("unexpected signal opcode {0:#04x}")]
313    UnexpectedOpcode(u8),
314    #[error("failed to deserialize VersionedTransaction: {0}")]
315    DeserializeTransaction(bincode::Error),
316    #[error("protocol error: {0}")]
317    Protocol(#[from] protocol::WireError),
318}
319
320pub struct SignalStream {
321    stream: TcpStream,
322    max_frame_body_len: usize,
323    header: [u8; protocol::FRAME_HEADER_LEN],
324    read_buf: Vec<u8>,
325    read_pos: usize,
326    payload_range: Range<usize>,
327}
328
329impl SignalStream {
330    pub async fn connect(config: SignalClientConfig) -> Result<Self, SignalClientError> {
331        let addrs = config.endpoint.resolve_addrs().await?;
332        let mut last_error = None;
333        for addr in addrs {
334            match Self::connect_addr(&config, addr).await {
335                Ok(stream) => return Ok(stream),
336                Err(err) => last_error = Some(err),
337            }
338        }
339
340        Err(last_error
341            .unwrap_or_else(|| SignalClientError::NoResolvedAddresses(config.endpoint.to_string())))
342    }
343
344    async fn connect_addr(
345        config: &SignalClientConfig,
346        addr: SocketAddr,
347    ) -> Result<Self, SignalClientError> {
348        let stream = match time::timeout(config.connect_timeout, TcpStream::connect(addr)).await {
349            Ok(result) => result?,
350            Err(_) => return Err(SignalClientError::ConnectTimeout(config.connect_timeout)),
351        };
352        stream.set_nodelay(config.tcp_nodelay)?;
353
354        let mut client = Self {
355            stream,
356            max_frame_body_len: config.max_frame_body_len,
357            header: [0; protocol::FRAME_HEADER_LEN],
358            read_buf: Vec::with_capacity(64 * 1024),
359            read_pos: 0,
360            payload_range: 0..0,
361        };
362        client.write_start_stream(config.program_id).await?;
363        Ok(client)
364    }
365
366    pub async fn next_transaction(&mut self) -> Result<SignalTransaction, SignalClientError> {
367        self.read_next_tx_payload().await?;
368
369        let payload = self.current_payload();
370        let metadata = protocol::parse_tx_payload_metadata(payload)?;
371        let transaction = bincode::deserialize(&payload[protocol::TX_METADATA_LEN..])
372            .map_err(SignalClientError::DeserializeTransaction)?;
373
374        Ok(SignalTransaction {
375            slot: metadata.slot,
376            created_at_unix_nanos: metadata.created_at_ns,
377            signature: Signature::from(metadata.signature),
378            transaction,
379        })
380    }
381
382    pub async fn next_raw(&mut self) -> Result<RawSignalTransaction, SignalClientError> {
383        Ok(self.next_raw_borrowed().await?.to_owned())
384    }
385
386    pub async fn next_raw_borrowed(
387        &mut self,
388    ) -> Result<BorrowedRawSignalTransaction<'_>, SignalClientError> {
389        self.read_next_tx_payload().await?;
390        self.borrow_raw_payload()
391    }
392
393    fn borrow_raw_payload(&self) -> Result<BorrowedRawSignalTransaction<'_>, SignalClientError> {
394        let payload = self.current_payload();
395        let metadata = protocol::parse_tx_payload_metadata(payload)?;
396        Ok(BorrowedRawSignalTransaction {
397            slot: metadata.slot,
398            created_at_unix_nanos: metadata.created_at_ns,
399            signature: Signature::from(metadata.signature),
400            serialized_transaction: &payload[protocol::TX_METADATA_LEN..],
401        })
402    }
403
404    pub fn into_inner(self) -> TcpStream {
405        self.stream
406    }
407
408    async fn write_start_stream(&mut self, program_id: [u8; 32]) -> Result<(), SignalClientError> {
409        let frame = protocol::encode_start_stream_frame(&program_id);
410        self.stream.write_all(&frame).await?;
411        Ok(())
412    }
413
414    async fn read_next_tx_payload(&mut self) -> Result<(), SignalClientError> {
415        self.read_next_frame().await?;
416
417        let opcode = self.header[4];
418        if opcode != protocol::MSG_TX {
419            return Err(SignalClientError::UnexpectedOpcode(opcode));
420        }
421        let payload = self.current_payload();
422        if payload.len() < protocol::TX_METADATA_LEN {
423            protocol::parse_tx_payload_metadata(payload)?;
424        }
425
426        Ok(())
427    }
428
429    async fn read_next_frame(&mut self) -> Result<(), SignalClientError> {
430        self.payload_range = 0..0;
431        self.compact_read_buffer();
432        self.ensure_buffered(protocol::FRAME_HEADER_LEN).await?;
433
434        let header_start = self.read_pos;
435        let header_end = header_start + protocol::FRAME_HEADER_LEN;
436        self.header
437            .copy_from_slice(&self.read_buf[header_start..header_end]);
438
439        let body_len = u32::from_le_bytes(self.header[..4].try_into().unwrap()) as usize;
440        if body_len == 0 {
441            return Err(SignalClientError::EmptyFrame);
442        }
443        if body_len > self.max_frame_body_len {
444            return Err(SignalClientError::FrameTooLarge {
445                len: body_len,
446                max: self.max_frame_body_len,
447            });
448        }
449
450        let payload_len = body_len - 1;
451        let frame_len = protocol::FRAME_HEADER_LEN + payload_len;
452        self.ensure_buffered(frame_len).await?;
453
454        let payload_start = header_end;
455        let payload_end = payload_start + payload_len;
456        self.payload_range = payload_start..payload_end;
457        self.read_pos = payload_end;
458        Ok(())
459    }
460
461    async fn ensure_buffered(&mut self, needed: usize) -> Result<(), SignalClientError> {
462        while self.read_buf.len().saturating_sub(self.read_pos) < needed {
463            let read = self.stream.read_buf(&mut self.read_buf).await?;
464            if read == 0 {
465                return Err(SignalClientError::StreamClosed);
466            }
467        }
468        Ok(())
469    }
470
471    fn compact_read_buffer(&mut self) {
472        if self.read_pos == 0 {
473            return;
474        }
475        if self.read_pos >= self.read_buf.len() {
476            self.read_buf.clear();
477            self.read_pos = 0;
478            return;
479        }
480
481        let remaining = self.read_buf.len() - self.read_pos;
482        self.read_buf.copy_within(self.read_pos.., 0);
483        self.read_buf.truncate(remaining);
484        self.read_pos = 0;
485    }
486
487    fn current_payload(&self) -> &[u8] {
488        &self.read_buf[self.payload_range.clone()]
489    }
490}
491
492#[derive(Debug, Clone)]
493pub struct SignalClientBuilder {
494    endpoint: SignalEndpoint,
495    program_id: Option<[u8; 32]>,
496    connect_timeout: Duration,
497    tcp_nodelay: bool,
498    max_frame_body_len: usize,
499    reconnect: bool,
500    reconnect_backoff: Duration,
501}
502
503impl Default for SignalClientBuilder {
504    fn default() -> Self {
505        Self {
506            endpoint: SignalEndpoint::default(),
507            program_id: None,
508            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
509            tcp_nodelay: true,
510            max_frame_body_len: DEFAULT_MAX_FRAME_BODY_LEN,
511            reconnect: true,
512            reconnect_backoff: DEFAULT_RECONNECT_BACKOFF,
513        }
514    }
515}
516
517impl SignalClientBuilder {
518    pub fn endpoint(mut self, endpoint: impl Into<SignalEndpoint>) -> Self {
519        self.endpoint = endpoint.into();
520        self
521    }
522
523    pub fn region(self, region: SignalRegion) -> Self {
524        self.endpoint(region)
525    }
526
527    pub fn addr(self, addr: SocketAddr) -> Self {
528        self.endpoint(addr)
529    }
530
531    pub fn ip(self, ip: IpAddr) -> Self {
532        self.endpoint(ip)
533    }
534
535    pub fn program_id(mut self, program_id: Pubkey) -> Self {
536        self.program_id = Some(program_id.to_bytes());
537        self
538    }
539
540    pub fn program_filter(mut self, program_filter: ProgramFilter) -> Self {
541        self.program_id = Some(program_filter.bytes());
542        self
543    }
544
545    pub fn program_id_bytes(mut self, program_id: [u8; 32]) -> Self {
546        self.program_id = Some(program_id);
547        self
548    }
549
550    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
551        self.connect_timeout = timeout;
552        self
553    }
554
555    pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
556        self.tcp_nodelay = enabled;
557        self
558    }
559
560    pub fn max_frame_body_len(mut self, len: usize) -> Self {
561        self.max_frame_body_len = len;
562        self
563    }
564
565    pub fn reconnect(mut self, enabled: bool) -> Self {
566        self.reconnect = enabled;
567        self
568    }
569
570    pub fn reconnect_backoff(mut self, backoff: Duration) -> Self {
571        self.reconnect_backoff = backoff;
572        self
573    }
574
575    pub async fn connect(self) -> Result<SignalClient, SignalClientError> {
576        SignalClient::connect(self).await
577    }
578
579    fn stream_config(&self) -> Result<SignalClientConfig, SignalClientError> {
580        let program_id = self
581            .program_id
582            .ok_or(SignalClientError::MissingProgramFilter)?;
583        Ok(SignalClientConfig {
584            endpoint: self.endpoint.clone(),
585            program_id,
586            connect_timeout: self.connect_timeout,
587            tcp_nodelay: self.tcp_nodelay,
588            max_frame_body_len: self.max_frame_body_len,
589        })
590    }
591}
592
593pub struct SignalClient {
594    stream_config: SignalClientConfig,
595    reconnect: bool,
596    reconnect_backoff: Duration,
597    stream: SignalStream,
598}
599
600impl SignalClient {
601    pub fn builder() -> SignalClientBuilder {
602        SignalClientBuilder::default()
603    }
604
605    pub async fn connect(builder: SignalClientBuilder) -> Result<Self, SignalClientError> {
606        let stream_config = builder.stream_config()?;
607        let stream = SignalStream::connect(stream_config.clone()).await?;
608        Ok(Self {
609            stream_config,
610            reconnect: builder.reconnect,
611            reconnect_backoff: builder.reconnect_backoff,
612            stream,
613        })
614    }
615
616    pub async fn next_raw(&mut self) -> Result<RawSignalTransaction, SignalClientError> {
617        loop {
618            match self.stream.next_raw().await {
619                Ok(tx) => return Ok(tx),
620                Err(err) if self.should_reconnect(&err) => self.reconnect().await?,
621                Err(err) => return Err(err),
622            }
623        }
624    }
625
626    pub async fn next_transaction(&mut self) -> Result<SignalTransaction, SignalClientError> {
627        self.next_raw().await?.into_transaction()
628    }
629
630    async fn reconnect(&mut self) -> Result<(), SignalClientError> {
631        loop {
632            time::sleep(self.reconnect_backoff).await;
633            match SignalStream::connect(self.stream_config.clone()).await {
634                Ok(stream) => {
635                    self.stream = stream;
636                    return Ok(());
637                }
638                Err(err) if self.should_reconnect(&err) => continue,
639                Err(err) => return Err(err),
640            }
641        }
642    }
643
644    fn should_reconnect(&self, err: &SignalClientError) -> bool {
645        self.reconnect
646            && matches!(
647                err,
648                SignalClientError::StreamClosed
649                    | SignalClientError::Io(_)
650                    | SignalClientError::ConnectTimeout(_)
651                    | SignalClientError::NoResolvedAddresses(_)
652            )
653    }
654
655    pub fn into_stream(self) -> SignalStream {
656        self.stream
657    }
658}
659
#[cfg(test)]
mod tests {
    use super::*;
    use solana_hash::Hash;
    use solana_instruction::Instruction;
    use solana_message::{Message, VersionedMessage};
    use tokio::net::TcpListener;

    #[test]
    fn parses_region_and_endpoint_values() {
        assert_eq!("EWR".parse::<SignalRegion>().unwrap(), SignalRegion::Ewr);
        assert_eq!("ams".parse::<SignalRegion>().unwrap(), SignalRegion::Ams);
        assert_eq!(SignalRegion::Ewr.hostname(), "ewr.signals.helius-rpc.com");

        assert_eq!(
            "ewr".parse::<SignalEndpoint>().unwrap(),
            SignalEndpoint::region(SignalRegion::Ewr)
        );
        assert_eq!(
            "127.0.0.1".parse::<SignalEndpoint>().unwrap(),
            SignalEndpoint::ip("127.0.0.1".parse::<IpAddr>().unwrap())
        );
        assert_eq!(
            "127.0.0.1:9001".parse::<SignalEndpoint>().unwrap(),
            SignalEndpoint::addr("127.0.0.1:9001".parse::<SocketAddr>().unwrap())
        );
    }

    #[tokio::test]
    async fn builder_requires_program_filter() {
        let result = SignalClient::builder()
            .addr("127.0.0.1:9000".parse::<SocketAddr>().unwrap())
            .connect()
            .await;

        assert!(matches!(
            result,
            Err(SignalClientError::MissingProgramFilter)
        ));
    }

    #[tokio::test]
    async fn connects_and_decodes_transaction() {
        let program_id = Pubkey::from([2u8; 32]);
        let signature = Signature::from([7u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            assert_eq!(
                u32::from_le_bytes(start[..4].try_into().unwrap()),
                protocol::START_STREAM_FRAME_BODY_LEN as u32
            );
            assert_eq!(start[4], protocol::MSG_START_STREAM);
            assert_eq!(&start[5..], &program_id.to_bytes());

            let frame = encode_tx_frame(42, 99, signature, &serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let received = stream.next_transaction().await.unwrap();

        assert_eq!(received.slot, 42);
        assert_eq!(received.created_at_unix_nanos, 99);
        assert_eq!(received.signature, signature);
        assert_eq!(received.transaction.signatures[0], signature);
        server.await.unwrap();
    }

    #[tokio::test]
    async fn connect_falls_back_to_next_available_addr() {
        let program_id = Pubkey::from([4u8; 32]);
        let signature = Signature::from([10u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();

        let bad_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bad_addr = bad_listener.local_addr().unwrap();
        drop(bad_listener);

        let good_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let good_addr = good_listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut socket, _) = good_listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            let frame = encode_tx_frame(46, 103, signature, &serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let endpoint = SignalEndpoint {
            kind: SignalEndpointKind::StaticAddrs(vec![bad_addr, good_addr]),
        };
        let mut config = SignalClientConfig::new(endpoint, program_id);
        config.connect_timeout = Duration::from_secs(1);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let received = stream.next_raw().await.unwrap();

        assert_eq!(received.slot, 46);
        assert_eq!(received.created_at_unix_nanos, 103);
        assert_eq!(received.signature, signature);
    }

    #[tokio::test]
    async fn raw_mode_preserves_serialized_transaction_bytes() {
        let program_id = Pubkey::from([3u8; 32]);
        let signature = Signature::from([8u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();
        let server_serialized = serialized.clone();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            let frame = encode_tx_frame(43, 100, signature, &server_serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let received = stream.next_raw().await.unwrap();

        assert_eq!(received.slot, 43);
        assert_eq!(received.created_at_unix_nanos, 100);
        assert_eq!(received.signature, signature);
        assert_eq!(&received.serialized_transaction[..], serialized.as_slice());
    }

    #[tokio::test]
    async fn borrowed_raw_mode_avoids_copying_serialized_transaction() {
        let program_id = Pubkey::from([6u8; 32]);
        let signature = Signature::from([9u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();
        let server_serialized = serialized.clone();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
            let frame = encode_tx_frame(44, 101, signature, &server_serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let received = stream.next_raw_borrowed().await.unwrap();

        assert_eq!(received.slot, 44);
        assert_eq!(received.created_at_unix_nanos, 101);
        assert_eq!(received.signature, signature);
        assert_eq!(received.serialized_transaction, serialized.as_slice());
    }

    #[tokio::test]
    async fn peer_close_returns_stream_closed_error() {
        let program_id = Pubkey::from([5u8; 32]);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            socket.read_exact(&mut start).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let err = stream.next_transaction().await.unwrap_err();

        assert!(matches!(err, SignalClientError::StreamClosed));
    }

    #[tokio::test]
    async fn signal_client_reconnects_after_peer_close() {
        let program_id = Pubkey::from([7u8; 32]);
        let signature = Signature::from([11u8; 64]);
        let transaction = test_transaction(program_id, signature);
        let serialized = bincode::serialize(&transaction).unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut first, _) = listener.accept().await.unwrap();
            let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
            first.read_exact(&mut start).await.unwrap();
            drop(first);

            let (mut second, _) = listener.accept().await.unwrap();
            second.read_exact(&mut start).await.unwrap();
            let frame = encode_tx_frame(45, 102, signature, &serialized);
            second.write_all(&frame).await.unwrap();
        });

        let mut client = SignalClient::builder()
            .addr(addr)
            .program_id(program_id)
            .reconnect_backoff(Duration::from_millis(1))
            .connect()
            .await
            .unwrap();
        let received = time::timeout(Duration::from_secs(2), client.next_raw())
            .await
            .unwrap()
            .unwrap();

        assert_eq!(received.slot, 45);
        assert_eq!(received.created_at_unix_nanos, 102);
        assert_eq!(received.signature, signature);
    }

    /// Several frames in one TCP write, then one frame whose header and body
    /// arrive in separate writes: exercises the buffering/compaction paths.
    #[tokio::test]
    async fn reads_batched_and_split_frames() {
        let program_id = Pubkey::from([12u8; 32]);
        let signature = Signature::from([13u8; 64]);
        let serialized = bincode::serialize(&test_transaction(program_id, signature)).unwrap();
        let server_serialized = serialized.clone();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let mut socket = accept_after_start(&listener).await;

            let mut batch = encode_tx_frame(50, 200, signature, &server_serialized);
            batch.extend_from_slice(&encode_tx_frame(51, 201, signature, &server_serialized));
            socket.write_all(&batch).await.unwrap();

            let split = encode_tx_frame(52, 202, signature, &server_serialized);
            let (head, tail) = split.split_at(protocol::FRAME_HEADER_LEN + 3);
            socket.write_all(head).await.unwrap();
            time::sleep(Duration::from_millis(20)).await;
            socket.write_all(tail).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        for (slot, created_at) in [(50u64, 200i64), (51, 201), (52, 202)] {
            let received = stream.next_raw_borrowed().await.unwrap();
            assert_eq!(received.slot, slot);
            assert_eq!(received.created_at_unix_nanos, created_at);
            assert_eq!(received.signature, signature);
            assert_eq!(received.serialized_transaction, serialized.as_slice());
        }
    }

    #[tokio::test]
    async fn oversized_frame_is_rejected() {
        let program_id = Pubkey::from([14u8; 32]);
        let signature = Signature::from([15u8; 64]);
        let serialized = bincode::serialize(&test_transaction(program_id, signature)).unwrap();
        let frame = encode_tx_frame(53, 203, signature, &serialized);
        let body_len = u32::from_le_bytes(frame[..4].try_into().unwrap()) as usize;
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let mut socket = accept_after_start(&listener).await;
            socket.write_all(&frame).await.unwrap();
        });

        let mut config = SignalClientConfig::new(addr, program_id);
        config.max_frame_body_len = body_len - 1;
        let mut stream = SignalStream::connect(config).await.unwrap();
        let err = stream.next_raw().await.unwrap_err();

        assert!(matches!(
            err,
            SignalClientError::FrameTooLarge { len, max } if len == body_len && max == body_len - 1
        ));
    }

    #[tokio::test]
    async fn unexpected_opcode_is_reported_then_skipped() {
        let program_id = Pubkey::from([16u8; 32]);
        let signature = Signature::from([17u8; 64]);
        let serialized = bincode::serialize(&test_transaction(program_id, signature)).unwrap();
        let bogus_opcode = protocol::MSG_TX.wrapping_add(1);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let mut socket = accept_after_start(&listener).await;
            let mut frames = 1u32.to_le_bytes().to_vec();
            frames.push(bogus_opcode);
            frames.extend_from_slice(&encode_tx_frame(54, 204, signature, &serialized));
            socket.write_all(&frames).await.unwrap();
        });

        let config = SignalClientConfig::new(addr, program_id);
        let mut stream = SignalStream::connect(config).await.unwrap();
        let err = stream.next_raw().await.unwrap_err();
        assert!(matches!(err, SignalClientError::UnexpectedOpcode(op) if op == bogus_opcode));

        let received = stream.next_raw().await.unwrap();
        assert_eq!(received.slot, 54);
        assert_eq!(received.created_at_unix_nanos, 204);
    }

    #[tokio::test]
    async fn signal_client_next_transaction_decodes_transaction() {
        let program_id = Pubkey::from([18u8; 32]);
        let signature = Signature::from([19u8; 64]);
        let serialized = bincode::serialize(&test_transaction(program_id, signature)).unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let mut socket = accept_after_start(&listener).await;
            let frame = encode_tx_frame(55, 205, signature, &serialized);
            socket.write_all(&frame).await.unwrap();
        });

        let mut client = SignalClient::builder()
            .addr(addr)
            .program_id(program_id)
            .reconnect(false)
            .connect()
            .await
            .unwrap();
        let received = client.next_transaction().await.unwrap();

        assert_eq!(received.slot, 55);
        assert_eq!(received.created_at_unix_nanos, 205);
        assert_eq!(received.signature, signature);
        assert_eq!(received.transaction.signatures[0], signature);
    }

    /// Accepts one connection, consumes the client's start-stream frame and
    /// hands back the socket for the test to script.
    async fn accept_after_start(listener: &TcpListener) -> tokio::net::TcpStream {
        let (mut socket, _) = listener.accept().await.unwrap();
        let mut start = [0u8; protocol::START_STREAM_FRAME_LEN];
        socket.read_exact(&mut start).await.unwrap();
        socket
    }

    fn test_transaction(program_id: Pubkey, signature: Signature) -> VersionedTransaction {
        let payer = Pubkey::from([1u8; 32]);
        let instruction = Instruction::new_with_bytes(program_id, &[1, 2, 3], vec![]);
        let message = Message::new_with_blockhash(&[instruction], Some(&payer), &Hash::default());
        VersionedTransaction {
            signatures: vec![signature],
            message: VersionedMessage::Legacy(message),
        }
    }

    fn encode_tx_frame(
        slot: u64,
        created_at_unix_nanos: i64,
        signature: Signature,
        serialized_transaction: &[u8],
    ) -> Vec<u8> {
        let payload_len = protocol::TX_METADATA_LEN + serialized_transaction.len();
        let body_len = 1 + payload_len;
        let mut frame = Vec::with_capacity(protocol::FRAME_HEADER_LEN + body_len);
        frame.extend_from_slice(&(body_len as u32).to_le_bytes());
        frame.push(protocol::MSG_TX);
        frame.extend_from_slice(&slot.to_le_bytes());
        frame.extend_from_slice(&created_at_unix_nanos.to_le_bytes());
        frame.extend_from_slice(signature.as_ref());
        frame.extend_from_slice(serialized_transaction);
        frame
    }
}