Skip to main content

electrum_client/
raw_client.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Raw client
4//!
5//! This module contains the definition of the raw client that wraps the transport method
6
7use std::borrow::Borrow;
8use std::collections::{BTreeMap, HashMap, VecDeque};
9use std::io::{BufRead, BufReader, Read, Write};
10use std::mem::drop;
11use std::net::{TcpStream, ToSocketAddrs};
12use std::sync::atomic::{AtomicUsize, Ordering};
13use std::sync::mpsc::{channel, Receiver, Sender};
14use std::sync::{Arc, Mutex, TryLockError};
15use std::time::Duration;
16
17#[allow(unused_imports)]
18use log::{debug, error, info, trace, warn};
19
20use bitcoin::consensus::encode::deserialize;
21use bitcoin::hex::{DisplayHex, FromHex};
22use bitcoin::{Script, Txid};
23
24#[cfg(feature = "openssl")]
25use openssl::ssl::{SslConnector, SslMethod, SslStream, SslVerifyMode};
26
27#[cfg(any(feature = "rustls", feature = "rustls-ring"))]
28#[allow(unused_imports)]
29use rustls::{
30    pki_types::ServerName,
31    pki_types::{Der, TrustAnchor},
32    ClientConfig, ClientConnection, RootCertStore, StreamOwned,
33};
34
35#[cfg(feature = "proxy")]
36use crate::socks::{Socks5Stream, TargetAddr, ToTargetAddr};
37
38use crate::stream::ClonableStream;
39
40use crate::api::ElectrumApi;
41use crate::batch::Batch;
42use crate::config::AuthProvider;
43use crate::types::*;
44
/// Client name sent to the server during protocol version negotiation.
///
/// This is the first parameter of the `server.version` request. It is currently the empty
/// string, so no client identifier is advertised.
pub const CLIENT_NAME: &str = "";

/// Minimum protocol version supported by this client.
///
/// Sent as the lower bound of the version range in the `server.version` request.
pub const PROTOCOL_VERSION_MIN: &str = "1.4";

/// Maximum protocol version supported by this client.
///
/// Sent as the upper bound of the version range in the `server.version` request.
pub const PROTOCOL_VERSION_MAX: &str = "1.6";
53
/// Returns `true` when `version` denotes a protocol version of at least `major.minor`.
///
/// Only the first two dot-separated components of `version` are examined. Both must parse as
/// unsigned integers, otherwise the result is `false`; any further components are ignored.
fn is_protocol_version_at_least(version: &str, major: u32, minor: u32) -> bool {
    let mut components = version.split('.').map(|part| part.parse::<u32>().ok());
    let found_major = components.next().flatten();
    let found_minor = components.next().flatten();

    // Tuples compare lexicographically: the major number first, then the minor one.
    found_major
        .zip(found_minor)
        .map_or(false, |found| found >= (major, minor))
}
64
/// Builds a [`Batch`] by invoking `$call` once per item of `$data`, sends it through
/// `$self.batch_call` and deserializes every returned value, in order, into a `Vec` that is
/// returned wrapped in `Ok`.
///
/// The optional trailing `apply_deref` marker makes the macro dereference the borrowed item
/// (`*item.borrow()`) before handing it to `$call`.
macro_rules! impl_batch_call {
    // Plain form: hand the borrowed item to `$call` as-is.
    ( $self:expr, $data:expr, $call:ident ) => {{
        impl_batch_call!($self, $data, $call, )
    }};

    // `apply_deref` form: dereference the borrowed item first.
    ( $self:expr, $data:expr, $call:ident, apply_deref ) => {{
        impl_batch_call!($self, $data, $call, *)
    }};

    // Worker arm: `$apply_deref` is either empty or a single `*` token.
    ( $self:expr, $data:expr, $call:ident, $($apply_deref:tt)? ) => {{
        let mut requests = Batch::default();
        for item in $data {
            requests.$call($($apply_deref)? item.borrow());
        }

        let mut decoded = Vec::new();
        for value in $self.batch_call(&requests)? {
            decoded.push(serde_json::from_value(value)?);
        }

        Ok(decoded)
    }};
}
90
/// Extension of [`ToSocketAddrs`](https://doc.rust-lang.org/std/net/trait.ToSocketAddrs.html)
/// for address types that may also carry a domain name. The domain is what an SSL client
/// checks the server's certificate against.
pub trait ToSocketAddrsDomain: ToSocketAddrs {
    /// Returns the domain, if present.
    ///
    /// The default implementation reports that there is none.
    fn domain(&self) -> Option<&str> {
        Option::None
    }
}
100
101impl ToSocketAddrsDomain for &str {
102    fn domain(&self) -> Option<&str> {
103        self.split(':').next()
104    }
105}
106
107impl ToSocketAddrsDomain for (&str, u16) {
108    fn domain(&self) -> Option<&str> {
109        self.0.domain()
110    }
111}
112
#[cfg(feature = "proxy")]
impl ToSocketAddrsDomain for TargetAddr {
    /// Returns the name of a [`TargetAddr::Domain`] target; IP targets have no domain.
    fn domain(&self) -> Option<&str> {
        if let TargetAddr::Domain(name, _) = self {
            Some(name.as_str())
        } else {
            None
        }
    }
}
122
123macro_rules! impl_to_socket_addrs_domain {
124    ( $ty:ty ) => {
125        impl ToSocketAddrsDomain for $ty {}
126    };
127}
128
129impl_to_socket_addrs_domain!(std::net::SocketAddr);
130impl_to_socket_addrs_domain!(std::net::SocketAddrV4);
131impl_to_socket_addrs_domain!(std::net::SocketAddrV6);
132impl_to_socket_addrs_domain!((std::net::IpAddr, u16));
133impl_to_socket_addrs_domain!((std::net::Ipv4Addr, u16));
134impl_to_socket_addrs_domain!((std::net::Ipv6Addr, u16));
135
/// Instance of an Electrum client
///
/// A [`RawClient`] maintains a constant connection with an Electrum server and exposes methods to
/// interact with it. It can also subscribe and receive notifications from the server about new
/// blocks or activity on a specific *scriptPubKey*.
///
/// The [`RawClient`] is modeled in such a way that allows the external caller to have full control over
/// its functionality: no threads or tasks are spawned internally to monitor the state of the
/// connection.
///
/// More transport methods can be used by manually creating an instance of this struct with an
/// arbitrary `S` type.
pub struct RawClient<S>
where
    S: Read + Write,
{
    /// Handle used to write serialized requests to the server.
    stream: Mutex<ClonableStream<S>>,
    /// Buffered reader over a clone of `stream`. The thread that manages to lock it reads the
    /// server's messages on behalf of every pending request.
    buf_reader: Mutex<BufReader<ClonableStream<S>>>,

    /// Counter from which request ids are drawn.
    last_id: AtomicUsize,
    /// Pending requests: maps a request id to the sender through which the thread reading the
    /// stream hands messages to the thread that issued that request.
    waiting_map: Mutex<HashMap<usize, Sender<ChannelMessage>>>,

    /// Queue of header notifications.
    // NOTE(review): presumably filled by `handle_notification`, whose body is not shown here.
    headers: Mutex<VecDeque<RawHeaderNotification>>,
    /// Queues of script status notifications, keyed by script hash.
    // NOTE(review): presumably filled by `handle_notification`, whose body is not shown here.
    script_notifications: Mutex<HashMap<ScriptHash, VecDeque<ScriptStatus>>>,

    /// The protocol version negotiated with the server via `server.version`.
    protocol_version: Mutex<Option<String>>,

    /// Optional authorization provider for dynamic token injection (e.g., JWT).
    auth_provider: Option<AuthProvider>,

    /// Call counter, starting at zero.
    // NOTE(review): presumably bumped by `increment_calls` after each request is written.
    calls: AtomicUsize,
}
169
170// Custom Debug impl because AuthProvider doesn't implement Debug
171impl<S> std::fmt::Debug for RawClient<S>
172where
173    S: Read + Write,
174{
175    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176        f.debug_struct("RawClient")
177            .field("stream", &"<stream>")
178            .field("buf_reader", &"<buf_reader>")
179            .field("last_id", &self.last_id)
180            .field("waiting_map", &self.waiting_map)
181            .field("headers", &self.headers)
182            .field("script_notifications", &self.script_notifications)
183            .field(
184                "auth_provider",
185                &self.auth_provider.as_ref().map(|_| "<provider>"),
186            )
187            .finish()
188    }
189}
190
191impl<S> From<S> for RawClient<S>
192where
193    S: Read + Write,
194{
195    fn from(stream: S) -> Self {
196        let stream: ClonableStream<_> = stream.into();
197
198        Self {
199            buf_reader: Mutex::new(BufReader::new(stream.clone())),
200            stream: Mutex::new(stream),
201
202            last_id: AtomicUsize::new(0),
203            waiting_map: Mutex::new(HashMap::new()),
204
205            headers: Mutex::new(VecDeque::new()),
206            script_notifications: Mutex::new(HashMap::new()),
207
208            protocol_version: Mutex::new(None),
209
210            auth_provider: None,
211
212            calls: AtomicUsize::new(0),
213        }
214    }
215}
216
217/// Transport type used to establish a plaintext TCP connection with the server
218pub type ElectrumPlaintextStream = TcpStream;
219impl RawClient<ElectrumPlaintextStream> {
220    /// Creates a new plaintext client and tries to connect to `socket_addr`.
221    ///
222    /// Automatically negotiates the protocol version with the server using
223    /// `server.version` as required by the Electrum protocol.
224    pub fn new<A: ToSocketAddrs>(
225        socket_addrs: A,
226        timeout: Option<Duration>,
227        auth_provider: Option<AuthProvider>,
228    ) -> Result<Self, Error> {
229        let stream = match timeout {
230            Some(timeout) => {
231                let stream = connect_with_total_timeout(socket_addrs, timeout)?;
232                stream.set_read_timeout(Some(timeout))?;
233                stream.set_write_timeout(Some(timeout))?;
234                stream
235            }
236            None => TcpStream::connect(socket_addrs)?,
237        };
238
239        let client = Self::from(stream)
240            .with_auth(auth_provider)
241            .negotiate_protocol_version()?;
242
243        Ok(client)
244    }
245}
246
247fn connect_with_total_timeout<A: ToSocketAddrs>(
248    socket_addrs: A,
249    mut timeout: Duration,
250) -> Result<TcpStream, Error> {
251    // Use the same algorithm as curl: 1/2 on the first host, 1/4 on the second one, etc.
252    // https://curl.se/mail/lib-2014-11/0164.html
253
254    let mut errors = Vec::new();
255
256    let addrs = socket_addrs
257        .to_socket_addrs()?
258        .enumerate()
259        .collect::<Vec<_>>();
260    for (index, addr) in &addrs {
261        if *index < addrs.len() - 1 {
262            timeout = timeout.div_f32(2.0);
263        }
264
265        info!(
266            "Trying to connect to {} (attempt {}/{}) with timeout {:?}",
267            addr,
268            index + 1,
269            addrs.len(),
270            timeout
271        );
272        match TcpStream::connect_timeout(addr, timeout) {
273            Ok(socket) => return Ok(socket),
274            Err(e) => {
275                warn!("Connection error: {:?}", e);
276                errors.push(e.into());
277            }
278        }
279    }
280
281    Err(Error::AllAttemptsErrored(errors))
282}
283
#[cfg(feature = "openssl")]
/// Transport type used to establish an OpenSSL TLS encrypted/authenticated connection with the server
pub type ElectrumSslStream = SslStream<TcpStream>;
#[cfg(feature = "openssl")]
impl RawClient<ElectrumSslStream> {
    /// Creates a new SSL client and tries to connect to `socket_addrs`. Optionally, if
    /// `validate_domain` is `true`, validate the server's certificate.
    ///
    /// When `timeout` is given it bounds the TCP connection attempt and is also installed as
    /// the read and write timeout of the socket.
    pub fn new_ssl<A: ToSocketAddrsDomain + Clone>(
        socket_addrs: A,
        validate_domain: bool,
        timeout: Option<Duration>,
        auth_provider: Option<AuthProvider>,
    ) -> Result<Self, Error> {
        debug!(
            "new_ssl socket_addrs.domain():{:?} validate_domain:{} timeout:{:?}",
            socket_addrs.domain(),
            validate_domain,
            timeout
        );

        // Fail before touching the network if validation is impossible.
        if validate_domain && socket_addrs.domain().is_none() {
            return Err(Error::MissingDomain);
        }

        let tcp_stream = if let Some(timeout) = timeout {
            let connected = connect_with_total_timeout(socket_addrs.clone(), timeout)?;
            connected.set_read_timeout(Some(timeout))?;
            connected.set_write_timeout(Some(timeout))?;
            connected
        } else {
            TcpStream::connect(socket_addrs.clone())?
        };

        Self::new_ssl_from_stream(socket_addrs, validate_domain, tcp_stream, auth_provider)
    }

    /// Create a new SSL client using an existing TcpStream
    ///
    /// `socket_addrs` is only consulted for its domain. With `validate_domain` set, a missing
    /// domain is an error; otherwise certificate verification is switched off and the
    /// placeholder `"NONE"` is used when there is no domain.
    pub fn new_ssl_from_stream<A: ToSocketAddrsDomain>(
        socket_addrs: A,
        validate_domain: bool,
        stream: TcpStream,
        auth_provider: Option<AuthProvider>,
    ) -> Result<Self, Error> {
        let mut connector_builder =
            SslConnector::builder(SslMethod::tls()).map_err(Error::InvalidSslMethod)?;

        // TODO: support for certificate pinning
        if !validate_domain {
            connector_builder.set_verify(SslVerifyMode::NONE);
        } else if socket_addrs.domain().is_none() {
            return Err(Error::MissingDomain);
        }

        let host = socket_addrs.domain().unwrap_or("NONE");
        let tls_stream = connector_builder
            .build()
            .connect(host, stream)
            .map_err(Error::SslHandshakeError)?;

        // Authentication must be in place before the `server.version` handshake.
        Self::from(tls_stream)
            .with_auth(auth_provider)
            .negotiate_protocol_version()
    }
}
351
#[cfg(any(feature = "rustls", feature = "rustls-ring"))]
#[allow(unused)]
mod danger {
    use super::ServerName;
    use rustls::client::danger::{
        HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
    };
    use rustls::crypto::CryptoProvider;
    use rustls::pki_types::{CertificateDer, UnixTime};
    use rustls::{DigitallySignedStruct, SignatureScheme};

    /// Certificate verifier that accepts every certificate and every handshake signature.
    ///
    /// It provides no authentication of the server whatsoever. The wrapped provider is only
    /// used to report which signature schemes are supported.
    #[derive(Debug)]
    pub struct NoCertificateVerification(CryptoProvider);

    impl NoCertificateVerification {
        /// Creates a verifier advertising the signature schemes of `provider`.
        pub fn new(provider: CryptoProvider) -> Self {
            NoCertificateVerification(provider)
        }
    }

    impl ServerCertVerifier for NoCertificateVerification {
        /// Accepts the presented certificate chain without looking at it.
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, rustls::Error> {
            Ok(ServerCertVerified::assertion())
        }

        /// Accepts any TLS 1.2 handshake signature.
        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        /// Accepts any TLS 1.3 handshake signature.
        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        /// Reports the signature schemes supported by the wrapped provider.
        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            let Self(provider) = self;
            provider.signature_verification_algorithms.supported_schemes()
        }
    }
}
405
#[cfg(all(
    any(feature = "rustls", feature = "rustls-ring"),
    not(feature = "openssl")
))]
/// Transport type used to establish a Rustls TLS encrypted/authenticated connection with the server
pub type ElectrumSslStream = StreamOwned<ClientConnection, TcpStream>;
#[cfg(all(
    any(feature = "rustls", feature = "rustls-ring"),
    not(feature = "openssl")
))]
impl RawClient<ElectrumSslStream> {
    /// Creates a new SSL client and tries to connect to `socket_addrs`. Optionally, if
    /// `validate_domain` is `true`, validate the server's certificate.
    ///
    /// When `timeout` is given it bounds the TCP connection attempt and is also installed as
    /// the read and write timeout of the socket.
    ///
    /// # Errors
    ///
    /// Returns `Error::MissingDomain` when `validate_domain` is set but `socket_addrs` has no
    /// domain, and forwards any connection, TLS setup or protocol negotiation error.
    pub fn new_ssl<A: ToSocketAddrsDomain + Clone>(
        socket_addrs: A,
        validate_domain: bool,
        timeout: Option<Duration>,
        auth_provider: Option<AuthProvider>,
    ) -> Result<Self, Error> {
        debug!(
            "new_ssl socket_addrs.domain():{:?} validate_domain:{} timeout:{:?}",
            socket_addrs.domain(),
            validate_domain,
            timeout
        );

        // Fail before touching the network if validation is impossible.
        if validate_domain {
            socket_addrs.domain().ok_or(Error::MissingDomain)?;
        }

        let stream = match timeout {
            Some(timeout) => {
                let stream = connect_with_total_timeout(socket_addrs.clone(), timeout)?;
                stream.set_read_timeout(Some(timeout))?;
                stream.set_write_timeout(Some(timeout))?;
                stream
            }
            None => TcpStream::connect(socket_addrs.clone())?,
        };

        Self::new_ssl_from_stream(socket_addrs, validate_domain, stream, auth_provider)
    }

    /// Create a new SSL client using an existing TcpStream
    ///
    /// `socket_addr` is only consulted for its domain. With `validate_domain` set, the
    /// server's certificate is checked against the `webpki_roots` trust anchors and a missing
    /// domain is an error; otherwise certificate verification is disabled altogether and the
    /// placeholder `"NONE"` is used when there is no domain.
    ///
    /// If the process has no default rustls `CryptoProvider` yet, the one matching the enabled
    /// feature is installed.
    pub fn new_ssl_from_stream<A: ToSocketAddrsDomain>(
        socket_addr: A,
        validate_domain: bool,
        tcp_stream: TcpStream,
        auth_provider: Option<AuthProvider>,
    ) -> Result<Self, Error> {
        use std::convert::TryFrom;

        if rustls::crypto::CryptoProvider::get_default().is_none() {
            // We install a crypto provider depending on the set feature.
            #[cfg(all(feature = "rustls", not(feature = "rustls-ring")))]
            let provider = rustls::crypto::aws_lc_rs::default_provider();
            #[cfg(feature = "rustls-ring")]
            let provider = rustls::crypto::ring::default_provider();

            // `install_default` only fails when a default provider is already in place, which
            // means another thread won the race since the check above. That provider is
            // perfectly usable, so this must not abort the connection.
            if rustls::crypto::CryptoProvider::install_default(provider).is_err() {
                debug!("A default CryptoProvider was installed concurrently, using it");
            }
        }

        let builder = ClientConfig::builder();

        let config = if validate_domain {
            socket_addr.domain().ok_or(Error::MissingDomain)?;

            let store = webpki_roots::TLS_SERVER_ROOTS
                .iter()
                .map(|t| TrustAnchor {
                    subject: Der::from_slice(t.subject),
                    subject_public_key_info: Der::from_slice(t.spki),
                    name_constraints: t.name_constraints.map(Der::from_slice),
                })
                .collect::<RootCertStore>();

            // TODO: cert pinning
            builder.with_root_certificates(store).with_no_client_auth()
        } else {
            // The verifier only needs a provider to report the supported signature schemes.
            #[cfg(all(feature = "rustls", not(feature = "rustls-ring")))]
            let provider = rustls::crypto::aws_lc_rs::default_provider();
            #[cfg(feature = "rustls-ring")]
            let provider = rustls::crypto::ring::default_provider();

            builder
                .dangerous()
                .with_custom_certificate_verifier(std::sync::Arc::new(
                    danger::NoCertificateVerification::new(provider),
                ))
                .with_no_client_auth()
        };

        let domain = socket_addr.domain().unwrap_or("NONE").to_string();
        let server_name = ServerName::try_from(domain.clone())
            .map_err(|_| Error::InvalidDNSNameError(domain))?;
        let session = ClientConnection::new(std::sync::Arc::new(config), server_name)
            .map_err(Error::CouldNotCreateConnection)?;
        let stream = StreamOwned::new(session, tcp_stream);

        // Authentication must be in place before the `server.version` handshake.
        let client = Self::from(stream)
            .with_auth(auth_provider)
            .negotiate_protocol_version()?;

        Ok(client)
    }
}
526
#[cfg(feature = "proxy")]
/// Transport type used to establish a connection to a server through a socks proxy
pub type ElectrumProxyStream = Socks5Stream;
#[cfg(feature = "proxy")]
impl RawClient<ElectrumProxyStream> {
    /// Creates a new socks client and tries to connect to `target_addr` using `proxy_addr` as a
    /// socks proxy server. The DNS resolution of `target_addr`, if required, is done
    /// through the proxy. This allows to specify, for instance, `.onion` addresses.
    ///
    /// The proxy credentials, when present in `proxy`, are used to authenticate against the
    /// proxy. `timeout` is passed to the proxy connection and installed as the read and write
    /// timeout of the resulting stream.
    pub fn new_proxy<T: ToTargetAddr>(
        target_addr: T,
        proxy: &crate::Socks5Config,
        timeout: Option<Duration>,
        auth_provider: Option<AuthProvider>,
    ) -> Result<Self, Error> {
        let mut stream = match proxy.credentials.as_ref() {
            Some(cred) => Socks5Stream::connect_with_password(
                &proxy.addr,
                target_addr,
                &cred.username,
                &cred.password,
                timeout,
            )?,
            None => Socks5Stream::connect(&proxy.addr, target_addr, timeout)?,
        };
        stream.get_mut().set_read_timeout(timeout)?;
        stream.get_mut().set_write_timeout(timeout)?;

        // Authentication must be in place before the `server.version` handshake.
        let client = Self::from(stream)
            .with_auth(auth_provider)
            .negotiate_protocol_version()?;

        Ok(client)
    }

    #[cfg(all(
        any(feature = "openssl", feature = "rustls", feature = "rustls-ring",),
        feature = "proxy",
    ))]
    /// Creates a new TLS client that connects to `target_addr` using `proxy_addr` as a socks proxy
    /// server. The DNS resolution of `target_addr`, if required, is done through the proxy. This
    /// allows to specify, for instance, `.onion` addresses.
    ///
    /// The TLS session is layered on top of the proxied stream; `validate_domain` is forwarded
    /// to `new_ssl_from_stream` together with the converted target address.
    pub fn new_proxy_ssl<T: ToTargetAddr>(
        target_addr: T,
        validate_domain: bool,
        proxy: &crate::Socks5Config,
        timeout: Option<Duration>,
        auth_provider: Option<AuthProvider>,
    ) -> Result<RawClient<ElectrumSslStream>, Error> {
        // Convert once: the same target is used for the proxy connection and, afterwards, for
        // the TLS domain.
        let target = target_addr.to_target_addr()?;

        let mut stream = match proxy.credentials.as_ref() {
            Some(cred) => Socks5Stream::connect_with_password(
                &proxy.addr,
                target.clone(),
                &cred.username,
                &cred.password,
                timeout,
            )?,
            None => Socks5Stream::connect(&proxy.addr, target.clone(), timeout)?,
        };

        stream.get_mut().set_read_timeout(timeout)?;
        stream.get_mut().set_write_timeout(timeout)?;

        RawClient::new_ssl_from_stream(target, validate_domain, stream.into_inner(), auth_provider)
    }
}
594
/// Message handed from the thread currently reading the stream to a thread waiting for the
/// outcome of its own request.
#[derive(Debug)]
enum ChannelMessage {
    /// The raw JSON response addressed to the waiting request.
    Response(serde_json::Value),
    /// The previous reader is done: the receiver should try to become the reader itself.
    WakeUp,
    /// Reading from the stream failed; the error is shared between all the waiting requests.
    Error(Arc<std::io::Error>),
}
601
602impl<S: Read + Write> RawClient<S> {
603    // TODO: to enable this we have to find a way to allow concurrent read and writes to the
604    // underlying transport struct. This can be done pretty easily for TcpStream because it can be
605    // split into a "read" and a "write" object, but it's not as trivial for other types. Without
606    // such thing, this causes a deadlock, because the reader thread takes a lock on the
607    // `ClonableStream` before other threads can send a request to the server. They will block
608    // waiting for the reader to release the mutex, but this will never happen because the server
609    // didn't receive any request, so it has nothing to send back.
610    //
611    // pub fn reader_thread(&self) -> Result<(), Error> {
612    //     self._reader_thread(None).map(|_| ())
613    // }
614
615    /// Sets the [`AuthProvider`] for this client, enabling authentication on all
616    /// outgoing RPC requests.
617    ///
618    /// The `auth_provider` is a callback invoked before each request, allowing
619    /// dynamic token strategies such as automatic JWT refresh without
620    /// reconnecting the client. Passing `None` or not calling this method
621    /// disables authentication.
622    ///
623    /// # Notes
624    ///
625    /// This method should be called **before** [`RawClient::negotiate_protocol_version`],
626    /// as the initial `server.version` handshake also requires authentication
627    /// on protected servers.
628    fn with_auth(mut self, auth_provider: Option<AuthProvider>) -> Self {
629        self.auth_provider = auth_provider;
630        self
631    }
632
633    /// Negotiates the Electrum protocol version with the Electrum server.
634    ///
635    /// This sends `server.version` as the first message and stores the negotiated
636    /// protocol version.
637    ///
638    /// As of Electrum Protocol v1.6 it's a mandatory step, see:
639    /// <https://electrum-protocol.readthedocs.io/en/latest/protocol-changes.html#version-1-6>
640    ///
641    /// [`ClientType`]: crate::ClientType
642    fn negotiate_protocol_version(self) -> Result<Self, Error> {
643        let version_range = vec![
644            PROTOCOL_VERSION_MIN.to_string(),
645            PROTOCOL_VERSION_MAX.to_string(),
646        ];
647        let req = Request::new_id(
648            self.last_id.fetch_add(1, Ordering::SeqCst),
649            "server.version",
650            vec![
651                Param::String(CLIENT_NAME.to_string()),
652                Param::StringVec(version_range),
653            ],
654        );
655        let result = self.call(req)?;
656        let response: ServerVersionRes = serde_json::from_value(result)?;
657
658        *self.protocol_version.lock()? = Some(response.protocol_version);
659        Ok(self)
660    }
661
662    fn _reader_thread(&self, until_message: Option<usize>) -> Result<serde_json::Value, Error> {
663        let mut raw_resp = String::new();
664        let resp = match self.buf_reader.try_lock() {
665            Ok(mut reader) => {
666                trace!(
667                    "Starting reader thread with `until_message` = {:?}",
668                    until_message
669                );
670
671                if let Some(until_message) = until_message {
672                    // If we are trying to start a reader thread but the corresponding sender is
673                    // missing from the map, exit immediately. We might have already received a
674                    // response for that id, but we don't know it yet. Exiting here forces the
675                    // calling code to fallback to the sender-receiver method, and it should find
676                    // a message there waiting for it.
677                    if self.waiting_map.lock()?.get(&until_message).is_none() {
678                        return Err(Error::CouldntLockReader);
679                    }
680                }
681
682                // Loop over every message
683                loop {
684                    raw_resp.clear();
685
686                    match reader.read_line(&mut raw_resp) {
687                        Ok(n_bytes_read) => {
688                            if n_bytes_read == 0 {
689                                trace!("Reached UnexpectedEof");
690                                return Err(Error::IOError(std::io::Error::new(
691                                    std::io::ErrorKind::UnexpectedEof,
692                                    "unexpected EOF",
693                                )));
694                            }
695                        }
696                        Err(e) => {
697                            let error = Arc::new(e);
698                            for (_, s) in self.waiting_map.lock().unwrap().drain() {
699                                s.send(ChannelMessage::Error(error.clone()))?;
700                            }
701                            return Err(Error::SharedIOError(error));
702                        }
703                    }
704                    trace!("<== {}", raw_resp);
705
706                    let resp: serde_json::Value = serde_json::from_str(&raw_resp)?;
707
708                    // Normally there is and id, but it's missing for spontaneous notifications
709                    // from the server
710                    let resp_id = resp["id"]
711                        .as_str()
712                        .and_then(|s| s.parse().ok())
713                        .or_else(|| resp["id"].as_u64().map(|i| i as usize));
714                    match resp_id {
715                        Some(resp_id) if until_message == Some(resp_id) => {
716                            // We have a valid id and it's exactly the one we were waiting for!
717                            trace!(
718                                "Reader thread {} received a response for its request",
719                                resp_id
720                            );
721
722                            // Remove ourselves from the "waiting map"
723                            let mut map = self.waiting_map.lock()?;
724                            map.remove(&resp_id);
725
726                            // If the map is not empty, we select a random thread to become the
727                            // new reader thread.
728                            if let Some(err) = map.values().find_map(|sender| {
729                                sender
730                                    .send(ChannelMessage::WakeUp)
731                                    .map_err(|err| {
732                                        warn!("Unable to wake up a thread, trying some other");
733                                        err
734                                    })
735                                    .err()
736                            }) {
737                                error!("All the threads has failed, giving up");
738                                return Err(err)?;
739                            }
740
741                            break Ok(resp);
742                        }
743                        Some(resp_id) => {
744                            // We have an id, but it's not our response. Notify the thread and
745                            // move on
746                            trace!("Reader thread received response for {}", resp_id);
747
748                            if let Some(sender) = self.waiting_map.lock()?.remove(&resp_id) {
749                                sender.send(ChannelMessage::Response(resp))?;
750                            } else {
751                                warn!("Missing listener for {}", resp_id);
752                            }
753                        }
754                        None => {
755                            // No id, that's probably a notification.
756                            let mut resp = resp;
757
758                            if let Some(method) = resp["method"].take().as_str() {
759                                self.handle_notification(method, resp["params"].take())?;
760                            } else {
761                                warn!("Unexpected response: {:?}", resp);
762                            }
763                        }
764                    }
765                }
766            }
767            Err(TryLockError::WouldBlock) => {
768                // If we "WouldBlock" here it means that there's already a reader thread
769                // running somewhere.
770                Err(Error::CouldntLockReader)
771            }
772            Err(TryLockError::Poisoned(e)) => Err(e)?,
773        };
774
775        let resp = resp?;
776        if let Some(err) = resp.get("error") {
777            Err(Error::Protocol(err.clone()))
778        } else {
779            Ok(resp)
780        }
781    }
782
783    fn call(&self, req: Request) -> Result<serde_json::Value, Error> {
784        // Add our listener to the map before we send the request, to make sure we don't get a
785        // reply before the receiver is added
786        let (sender, receiver) = channel();
787        self.waiting_map.lock()?.insert(req.id, sender);
788
789        // apply `authorization` token into `Request`, if any.
790        let authorization = self
791            .auth_provider
792            .as_ref()
793            .and_then(|auth_provider| auth_provider());
794
795        let req = req.with_auth(authorization);
796
797        let mut raw = serde_json::to_vec(&req)?;
798        trace!("==> {}", String::from_utf8_lossy(&raw));
799
800        raw.extend_from_slice(b"\n");
801        let mut stream = self.stream.lock()?;
802        stream.write_all(&raw)?;
803        stream.flush()?;
804        drop(stream); // release the lock
805
806        self.increment_calls();
807
808        let mut resp = match self.recv(&receiver, req.id) {
809            Ok(resp) => resp,
810            e @ Err(_) => {
811                // In case of error our sender could still be left in the map, depending on where
812                // the error happened. Just in case, try to remove it here
813                self.waiting_map.lock()?.remove(&req.id);
814                return e;
815            }
816        };
817        Ok(resp["result"].take())
818    }
819
820    fn recv(
821        &self,
822        receiver: &Receiver<ChannelMessage>,
823        req_id: usize,
824    ) -> Result<serde_json::Value, Error> {
825        loop {
826            // Try to take the lock on the reader. If we manage to do so, we'll become the reader
827            // thread until we get our reponse
828            match self._reader_thread(Some(req_id)) {
829                Ok(response) => break Ok(response),
830                Err(Error::CouldntLockReader) => {
831                    match receiver.recv()? {
832                        // Received our response, returning it
833                        ChannelMessage::Response(received) => break Ok(received),
834                        ChannelMessage::WakeUp => {
835                            // We have been woken up, this means that we should try becoming the
836                            // reader thread ourselves
837                            trace!("WakeUp for {}", req_id);
838
839                            continue;
840                        }
841                        ChannelMessage::Error(e) => {
842                            warn!("Received ChannelMessage::Error");
843
844                            break Err(Error::SharedIOError(e));
845                        }
846                    }
847                }
848                e @ Err(_) => break e,
849            }
850        }
851    }
852
853    fn handle_notification(&self, method: &str, result: serde_json::Value) -> Result<(), Error> {
854        match method {
855            "blockchain.headers.subscribe" => self.headers.lock()?.append(
856                &mut serde_json::from_value::<Vec<RawHeaderNotification>>(result)?
857                    .into_iter()
858                    .collect(),
859            ),
860            "blockchain.scripthash.subscribe" => {
861                let unserialized: ScriptNotification = serde_json::from_value(result)?;
862                let mut script_notifications = self.script_notifications.lock()?;
863
864                let queue = script_notifications
865                    .get_mut(&unserialized.scripthash)
866                    .ok_or(Error::NotSubscribed(unserialized.scripthash))?;
867
868                queue.push_back(unserialized.status);
869            }
870            _ => info!("received unknown notification for method `{}`", method),
871        }
872
873        Ok(())
874    }
875
876    pub(crate) fn internal_raw_call_with_vec(
877        &self,
878        method_name: &str,
879        params: Vec<Param>,
880    ) -> Result<serde_json::Value, Error> {
881        let req = Request::new_id(
882            self.last_id.fetch_add(1, Ordering::SeqCst),
883            method_name,
884            params,
885        );
886        let result = self.call(req)?;
887
888        Ok(result)
889    }
890
891    #[inline]
892    fn increment_calls(&self) {
893        self.calls.fetch_add(1, Ordering::SeqCst);
894    }
895}
896
897impl<T: Read + Write> ElectrumApi for RawClient<T> {
898    fn raw_call(
899        &self,
900        method_name: &str,
901        params: impl IntoIterator<Item = Param>,
902    ) -> Result<serde_json::Value, Error> {
903        self.internal_raw_call_with_vec(method_name, params.into_iter().collect())
904    }
905
906    fn batch_call(&self, batch: &Batch) -> Result<Vec<serde_json::Value>, Error> {
907        let mut raw = Vec::new();
908
909        let mut missing_responses = Vec::new();
910        let mut answers = BTreeMap::new();
911
912        // Add our listener to the map before we send the request
913
914        for (idx, (method, params)) in batch.iter().enumerate() {
915            let mut req = Request::new_id(
916                self.last_id.fetch_add(1, Ordering::SeqCst),
917                method,
918                params.to_vec(),
919            );
920
921            // Although the library DOES NOT use JSON-RPC batch arrays,
922            // It applies the `authorization` ONLY in the first `Request` of the `Batch`.
923            //
924            // JWT tokens can be 1KB+, therefore duplicating it across multiple requests adds significant overhead.
925            // It assumes the server authenticates the `Batch` by the first `Request`. If a server implementation treats
926            // each newline-delimited request independently, subsequently `Request`'s would be unauthenticated.
927            //
928            // It's a known trade-off, not a bug.
929            if idx == 0 {
930                // it should get the `authorization`, if there's an `auth_provider` available.
931                let authorization = self
932                    .auth_provider
933                    .as_ref()
934                    .and_then(|auth_provider| auth_provider());
935
936                req = req.with_auth(authorization);
937            }
938
939            // Add distinct channel to each request so when we remove our request id (and sender) from the waiting_map
940            // we can be sure that the response gets sent to the correct channel in self.recv
941            let (sender, receiver) = channel();
942            missing_responses.push((req.id, receiver));
943
944            self.waiting_map.lock()?.insert(req.id, sender);
945
946            raw.append(&mut serde_json::to_vec(&req)?);
947            raw.extend_from_slice(b"\n");
948        }
949
950        if missing_responses.is_empty() {
951            return Ok(vec![]);
952        }
953
954        trace!("==> {}", String::from_utf8_lossy(&raw));
955
956        let mut stream = self.stream.lock()?;
957        stream.write_all(&raw)?;
958        stream.flush()?;
959        drop(stream); // release the lock
960
961        self.increment_calls();
962
963        for (req_id, receiver) in missing_responses.iter() {
964            match self.recv(receiver, *req_id) {
965                Ok(mut resp) => answers.insert(req_id, resp["result"].take()),
966                Err(e) => {
967                    // In case of error our sender could still be left in the map, depending on where
968                    // the error happened. Just in case, try to remove it here
969                    warn!("got error for req_id {}: {:?}", req_id, e);
970                    warn!("removing all waiting req of this batch");
971                    let mut guard = self.waiting_map.lock()?;
972                    for (req_id, _) in missing_responses.iter() {
973                        guard.remove(req_id);
974                    }
975                    return Err(e);
976                }
977            };
978        }
979
980        Ok(answers.into_values().collect())
981    }
982
983    fn block_headers_subscribe_raw(&self) -> Result<RawHeaderNotification, Error> {
984        let req = Request::new_id(
985            self.last_id.fetch_add(1, Ordering::SeqCst),
986            "blockchain.headers.subscribe",
987            vec![],
988        );
989        let value = self.call(req)?;
990
991        Ok(serde_json::from_value(value)?)
992    }
993
994    fn block_headers_pop_raw(&self) -> Result<Option<RawHeaderNotification>, Error> {
995        Ok(self.headers.lock()?.pop_front())
996    }
997
998    fn block_header_raw(&self, height: usize) -> Result<Vec<u8>, Error> {
999        let req = Request::new_id(
1000            self.last_id.fetch_add(1, Ordering::SeqCst),
1001            "blockchain.block.header",
1002            vec![Param::Usize(height)],
1003        );
1004        let result = self.call(req)?;
1005
1006        Ok(Vec::<u8>::from_hex(
1007            result
1008                .as_str()
1009                .ok_or_else(|| Error::InvalidResponse(result.clone()))?,
1010        )?)
1011    }
1012
1013    fn block_headers(&self, start_height: usize, count: usize) -> Result<GetHeadersRes, Error> {
1014        let req = Request::new_id(
1015            self.last_id.fetch_add(1, Ordering::SeqCst),
1016            "blockchain.block.headers",
1017            vec![Param::Usize(start_height), Param::Usize(count)],
1018        );
1019        let result = self.call(req)?;
1020
1021        // Check protocol version to determine response format
1022        let is_v1_6_or_later = {
1023            let protocol_version = self.protocol_version.lock()?;
1024            protocol_version
1025                .as_ref()
1026                .map(|v| is_protocol_version_at_least(v, 1, 6))
1027                .unwrap_or(false)
1028        };
1029
1030        if is_v1_6_or_later {
1031            // v1.6+: headers field contains array of hex strings
1032            let mut deserialized: GetHeadersRes = serde_json::from_value(result)?;
1033            for header_hex in &deserialized.header_hexes {
1034                let header_bytes = Vec::<u8>::from_hex(header_hex)?;
1035                deserialized.headers.push(deserialize(&header_bytes)?);
1036            }
1037            deserialized.header_hexes.clear();
1038            Ok(deserialized)
1039        } else {
1040            // v1.4: hex field contains concatenated headers
1041            let deserialized: GetHeadersResLegacy = serde_json::from_value(result)?;
1042            let mut headers = Vec::new();
1043            for i in 0..deserialized.count {
1044                let (start, end) = (i * 80, (i + 1) * 80);
1045                headers.push(deserialize(&deserialized.raw_headers[start..end])?);
1046            }
1047            Ok(GetHeadersRes {
1048                max: deserialized.max,
1049                count: deserialized.count,
1050                header_hexes: Vec::new(),
1051                headers,
1052            })
1053        }
1054    }
1055
1056    fn estimate_fee(&self, number: usize, mode: Option<EstimationMode>) -> Result<f64, Error> {
1057        let mut params = vec![Param::Usize(number)];
1058        if let Some(mode) = mode {
1059            params.push(Param::String(mode.to_string()));
1060        }
1061        let req = Request::new_id(
1062            self.last_id.fetch_add(1, Ordering::SeqCst),
1063            "blockchain.estimatefee",
1064            params,
1065        );
1066        let result = self.call(req)?;
1067
1068        result
1069            .as_f64()
1070            .ok_or_else(|| Error::InvalidResponse(result.clone()))
1071    }
1072
1073    fn relay_fee(&self) -> Result<f64, Error> {
1074        let req = Request::new_id(
1075            self.last_id.fetch_add(1, Ordering::SeqCst),
1076            "blockchain.relayfee",
1077            vec![],
1078        );
1079        let result = self.call(req)?;
1080
1081        result
1082            .as_f64()
1083            .ok_or_else(|| Error::InvalidResponse(result.clone()))
1084    }
1085
1086    fn script_subscribe(&self, script: &Script) -> Result<Option<ScriptStatus>, Error> {
1087        let script_hash = script.to_electrum_scripthash();
1088        let mut script_notifications = self.script_notifications.lock()?;
1089
1090        if script_notifications.contains_key(&script_hash) {
1091            return Err(Error::AlreadySubscribed(script_hash));
1092        }
1093
1094        script_notifications.insert(script_hash, VecDeque::new());
1095        drop(script_notifications);
1096
1097        let req = Request::new_id(
1098            self.last_id.fetch_add(1, Ordering::SeqCst),
1099            "blockchain.scripthash.subscribe",
1100            vec![Param::String(script_hash.to_hex())],
1101        );
1102        let value = self.call(req)?;
1103
1104        Ok(serde_json::from_value(value)?)
1105    }
1106
1107    fn batch_script_subscribe<'s, I>(&self, scripts: I) -> Result<Vec<Option<ScriptStatus>>, Error>
1108    where
1109        I: IntoIterator + Clone,
1110        I::Item: Borrow<&'s Script>,
1111    {
1112        {
1113            let mut script_notifications = self.script_notifications.lock()?;
1114
1115            for script in scripts.clone() {
1116                let script_hash = script.borrow().to_electrum_scripthash();
1117                if script_notifications.contains_key(&script_hash) {
1118                    return Err(Error::AlreadySubscribed(script_hash));
1119                }
1120                script_notifications.insert(script_hash, VecDeque::new());
1121            }
1122        }
1123        impl_batch_call!(self, scripts, script_subscribe)
1124    }
1125
1126    fn script_unsubscribe(&self, script: &Script) -> Result<bool, Error> {
1127        let script_hash = script.to_electrum_scripthash();
1128        let mut script_notifications = self.script_notifications.lock()?;
1129
1130        if !script_notifications.contains_key(&script_hash) {
1131            return Err(Error::NotSubscribed(script_hash));
1132        }
1133
1134        let req = Request::new_id(
1135            self.last_id.fetch_add(1, Ordering::SeqCst),
1136            "blockchain.scripthash.unsubscribe",
1137            vec![Param::String(script_hash.to_hex())],
1138        );
1139        let value = self.call(req)?;
1140        let answer = serde_json::from_value(value)?;
1141
1142        script_notifications.remove(&script_hash);
1143
1144        Ok(answer)
1145    }
1146
1147    fn script_pop(&self, script: &Script) -> Result<Option<ScriptStatus>, Error> {
1148        let script_hash = script.to_electrum_scripthash();
1149
1150        match self.script_notifications.lock()?.get_mut(&script_hash) {
1151            None => Err(Error::NotSubscribed(script_hash)),
1152            Some(queue) => Ok(queue.pop_front()),
1153        }
1154    }
1155
1156    fn script_get_balance(&self, script: &Script) -> Result<GetBalanceRes, Error> {
1157        let params = vec![Param::String(script.to_electrum_scripthash().to_hex())];
1158        let req = Request::new_id(
1159            self.last_id.fetch_add(1, Ordering::SeqCst),
1160            "blockchain.scripthash.get_balance",
1161            params,
1162        );
1163        let result = self.call(req)?;
1164
1165        Ok(serde_json::from_value(result)?)
1166    }
    /// Batch version of `script_get_balance`.
    ///
    /// All the work is delegated to the `impl_batch_call!` macro, which is defined outside this
    /// section. Presumably it queues one `script_get_balance` request per script and returns the
    /// results in input order — confirm against the macro definition.
    fn batch_script_get_balance<'s, I>(&self, scripts: I) -> Result<Vec<GetBalanceRes>, Error>
    where
        I: IntoIterator + Clone,
        I::Item: Borrow<&'s Script>,
    {
        impl_batch_call!(self, scripts, script_get_balance)
    }
1174
1175    fn script_get_history(&self, script: &Script) -> Result<Vec<GetHistoryRes>, Error> {
1176        let params = vec![Param::String(script.to_electrum_scripthash().to_hex())];
1177        let req = Request::new_id(
1178            self.last_id.fetch_add(1, Ordering::SeqCst),
1179            "blockchain.scripthash.get_history",
1180            params,
1181        );
1182        let result = self.call(req)?;
1183
1184        Ok(serde_json::from_value(result)?)
1185    }
    /// Batch version of `script_get_history`.
    ///
    /// All the work is delegated to the `impl_batch_call!` macro, which is defined outside this
    /// section. Presumably it queues one `script_get_history` request per script and returns the
    /// results in input order — confirm against the macro definition.
    fn batch_script_get_history<'s, I>(&self, scripts: I) -> Result<Vec<Vec<GetHistoryRes>>, Error>
    where
        I: IntoIterator + Clone,
        I::Item: Borrow<&'s Script>,
    {
        impl_batch_call!(self, scripts, script_get_history)
    }
1193
1194    fn script_list_unspent(&self, script: &Script) -> Result<Vec<ListUnspentRes>, Error> {
1195        let params = vec![Param::String(script.to_electrum_scripthash().to_hex())];
1196        let req = Request::new_id(
1197            self.last_id.fetch_add(1, Ordering::SeqCst),
1198            "blockchain.scripthash.listunspent",
1199            params,
1200        );
1201        let result = self.call(req)?;
1202        let mut result: Vec<ListUnspentRes> = serde_json::from_value(result)?;
1203
1204        // This should not be necessary, since the protocol documentation says that the txs should
1205        // be "in blockchain order" (https://electrumx.readthedocs.io/en/latest/protocol-methods.html#blockchain-scripthash-listunspent).
1206        // However, elects seems to be ignoring this at the moment, so we'll sort again here just
1207        // to make sure the result is consistent.
1208        result.sort_unstable_by_key(|k| (k.height, k.tx_pos));
1209        Ok(result)
1210    }
1211
    /// Batch version of `script_list_unspent`.
    ///
    /// All the work is delegated to the `impl_batch_call!` macro, which is defined outside this
    /// section. Presumably it queues one `script_list_unspent` request per script and returns the
    /// results in input order — confirm against the macro definition.
    fn batch_script_list_unspent<'s, I>(
        &self,
        scripts: I,
    ) -> Result<Vec<Vec<ListUnspentRes>>, Error>
    where
        I: IntoIterator + Clone,
        I::Item: Borrow<&'s Script>,
    {
        impl_batch_call!(self, scripts, script_list_unspent)
    }
1222
1223    fn transaction_get_raw(&self, txid: &Txid) -> Result<Vec<u8>, Error> {
1224        let params = vec![Param::String(format!("{:x}", txid))];
1225        let req = Request::new_id(
1226            self.last_id.fetch_add(1, Ordering::SeqCst),
1227            "blockchain.transaction.get",
1228            params,
1229        );
1230        let result = self.call(req)?;
1231
1232        Ok(Vec::<u8>::from_hex(
1233            result
1234                .as_str()
1235                .ok_or_else(|| Error::InvalidResponse(result.clone()))?,
1236        )?)
1237    }
1238
1239    fn batch_transaction_get_raw<'t, I>(&self, txids: I) -> Result<Vec<Vec<u8>>, Error>
1240    where
1241        I: IntoIterator + Clone,
1242        I::Item: Borrow<&'t Txid>,
1243    {
1244        let txs_string: Result<Vec<String>, Error> = impl_batch_call!(self, txids, transaction_get);
1245        txs_string?
1246            .iter()
1247            .map(|s| Ok(Vec::<u8>::from_hex(s)?))
1248            .collect()
1249    }
1250
1251    fn batch_block_header_raw<'s, I>(&self, heights: I) -> Result<Vec<Vec<u8>>, Error>
1252    where
1253        I: IntoIterator + Clone,
1254        I::Item: Borrow<u32>,
1255    {
1256        let headers_string: Result<Vec<String>, Error> =
1257            impl_batch_call!(self, heights, block_header, apply_deref);
1258        headers_string?
1259            .iter()
1260            .map(|s| Ok(Vec::<u8>::from_hex(s)?))
1261            .collect()
1262    }
1263
1264    fn batch_estimate_fee<'s, I>(&self, numbers: I) -> Result<Vec<f64>, Error>
1265    where
1266        I: IntoIterator + Clone,
1267        I::Item: Borrow<usize>,
1268    {
1269        let mut batch = Batch::default();
1270        for i in numbers {
1271            batch.estimate_fee(*i.borrow(), None);
1272        }
1273
1274        let resp = self.batch_call(&batch)?;
1275        let mut answer = Vec::new();
1276
1277        for x in resp {
1278            answer.push(serde_json::from_value(x)?);
1279        }
1280
1281        Ok(answer)
1282    }
1283
1284    fn transaction_broadcast_raw(&self, raw_tx: &[u8]) -> Result<Txid, Error> {
1285        let params = vec![Param::String(raw_tx.to_lower_hex_string())];
1286        let req = Request::new_id(
1287            self.last_id.fetch_add(1, Ordering::SeqCst),
1288            "blockchain.transaction.broadcast",
1289            params,
1290        );
1291        let result = self.call(req)?;
1292
1293        Ok(serde_json::from_value(result)?)
1294    }
1295
1296    fn transaction_broadcast_package_raw<Tx: AsRef<[u8]>>(
1297        &self,
1298        raw_txs: &[Tx],
1299    ) -> Result<BroadcastPackageRes, Error> {
1300        let hex_txs: Vec<String> = raw_txs
1301            .iter()
1302            .map(|tx| tx.as_ref().to_lower_hex_string())
1303            .collect();
1304        let params = vec![Param::StringVec(hex_txs)];
1305        let req = Request::new_id(
1306            self.last_id.fetch_add(1, Ordering::SeqCst),
1307            "blockchain.transaction.broadcast_package",
1308            params,
1309        );
1310        let result = self.call(req)?;
1311
1312        Ok(serde_json::from_value(result)?)
1313    }
1314
1315    fn transaction_get_merkle(&self, txid: &Txid, height: usize) -> Result<GetMerkleRes, Error> {
1316        let params = vec![Param::String(format!("{:x}", txid)), Param::Usize(height)];
1317        let req = Request::new_id(
1318            self.last_id.fetch_add(1, Ordering::SeqCst),
1319            "blockchain.transaction.get_merkle",
1320            params,
1321        );
1322        let result = self.call(req)?;
1323
1324        Ok(serde_json::from_value(result)?)
1325    }
1326
    /// Batch version of `transaction_get_merkle`, taking `(txid, height)` pairs.
    ///
    /// All the work is delegated to the `impl_batch_call!` macro, which is defined outside this
    /// section. Presumably it queues one `transaction_get_merkle` request per pair and returns
    /// the results in input order — confirm against the macro definition.
    fn batch_transaction_get_merkle<I>(
        &self,
        txids_and_heights: I,
    ) -> Result<Vec<GetMerkleRes>, Error>
    where
        I: IntoIterator + Clone,
        I::Item: Borrow<(Txid, usize)>,
    {
        impl_batch_call!(self, txids_and_heights, transaction_get_merkle)
    }
1337
1338    fn txid_from_pos(&self, height: usize, tx_pos: usize) -> Result<Txid, Error> {
1339        let params = vec![Param::Usize(height), Param::Usize(tx_pos)];
1340        let req = Request::new_id(
1341            self.last_id.fetch_add(1, Ordering::SeqCst),
1342            "blockchain.transaction.id_from_pos",
1343            params,
1344        );
1345        let result = self.call(req)?;
1346
1347        Ok(serde_json::from_value(result)?)
1348    }
1349
1350    fn txid_from_pos_with_merkle(
1351        &self,
1352        height: usize,
1353        tx_pos: usize,
1354    ) -> Result<TxidFromPosRes, Error> {
1355        let params = vec![
1356            Param::Usize(height),
1357            Param::Usize(tx_pos),
1358            Param::Bool(true),
1359        ];
1360        let req = Request::new_id(
1361            self.last_id.fetch_add(1, Ordering::SeqCst),
1362            "blockchain.transaction.id_from_pos",
1363            params,
1364        );
1365        let result = self.call(req)?;
1366
1367        Ok(serde_json::from_value(result)?)
1368    }
1369
1370    fn server_features(&self) -> Result<ServerFeaturesRes, Error> {
1371        let req = Request::new_id(
1372            self.last_id.fetch_add(1, Ordering::SeqCst),
1373            "server.features",
1374            vec![],
1375        );
1376        let result = self.call(req)?;
1377
1378        Ok(serde_json::from_value(result)?)
1379    }
1380
1381    fn mempool_get_info(&self) -> Result<MempoolInfoRes, Error> {
1382        let req = Request::new_id(
1383            self.last_id.fetch_add(1, Ordering::SeqCst),
1384            "mempool.get_info",
1385            vec![],
1386        );
1387        let result = self.call(req)?;
1388
1389        Ok(serde_json::from_value(result)?)
1390    }
1391
1392    fn ping(&self) -> Result<(), Error> {
1393        let req = Request::new_id(
1394            self.last_id.fetch_add(1, Ordering::SeqCst),
1395            "server.ping",
1396            vec![],
1397        );
1398        self.call(req)?;
1399
1400        Ok(())
1401    }
1402
1403    fn calls_made(&self) -> Result<usize, Error> {
1404        Ok(self.calls.load(Ordering::SeqCst))
1405    }
1406}
1407
1408#[cfg(test)]
1409mod test {
1410    use std::str::FromStr;
1411
1412    use crate::utils;
1413
1414    use super::{ElectrumSslStream, RawClient};
1415    use crate::api::ElectrumApi;
1416    use crate::config::AuthProvider;
1417
    // This is the default live Electrum server used by the tests. To use a custom one, set it
    // through the environment variable `TEST_ELECTRUM_SERVER`.
    //
    // Here's a useful list of live servers: https://1209k.com/bitcoin-eye/ele.php.
    const DEFAULT_TEST_ELECTRUM_SERVER: &str = "fortress.qtornado.com:443";
1423
1424    fn get_test_auth_client(
1425        authorization_provider: Option<AuthProvider>,
1426    ) -> RawClient<ElectrumSslStream> {
1427        let server = std::env::var("TEST_ELECTRUM_SERVER")
1428            .unwrap_or(DEFAULT_TEST_ELECTRUM_SERVER.to_owned());
1429
1430        RawClient::new_ssl(&*server, false, None, authorization_provider)
1431            .expect("should build the `RawClient` successfully!")
1432    }
1433
1434    fn get_test_client() -> RawClient<ElectrumSslStream> {
1435        let server = std::env::var("TEST_ELECTRUM_SERVER")
1436            .unwrap_or(DEFAULT_TEST_ELECTRUM_SERVER.to_owned());
1437
1438        RawClient::new_ssl(&*server, false, None, None)
1439            .expect("should build the `RawClient` successfully!")
1440    }
1441
1442    #[test]
1443    fn test_server_features_simple() {
1444        let client = get_test_client();
1445
1446        let resp = client.server_features().unwrap();
1447        assert_eq!(
1448            resp.genesis_hash,
1449            [
1450                0, 0, 0, 0, 0, 25, 214, 104, 156, 8, 90, 225, 101, 131, 30, 147, 79, 247, 99, 174,
1451                70, 162, 166, 193, 114, 179, 241, 182, 10, 140, 226, 111
1452            ],
1453        );
1454        assert_eq!(resp.hash_function, Some("sha256".into()));
1455        assert_eq!(resp.pruning, None);
1456    }
1457
1458    #[test]
1459    fn test_mempool_get_info() {
1460        let client = get_test_client();
1461
1462        let resp = client.mempool_get_info().unwrap();
1463        assert!(resp.mempoolminfee >= 0.0);
1464        assert!(resp.minrelaytxfee >= 0.0);
1465        assert!(resp.incrementalrelayfee >= 0.0);
1466    }
1467
1468    #[test]
1469    fn test_transaction_broadcast_package() {
1470        let client = get_test_client();
1471
1472        // Empty package should return an error or unsuccessful response
1473        let resp = client.transaction_broadcast_package_raw::<Vec<u8>>(&[]);
1474        // The server may reject an empty package with a protocol error
1475        assert!(resp.is_err() || !resp.unwrap().success);
1476    }
1477
1478    #[test]
1479    #[ignore = "depends on a live server"]
1480    fn test_batch_response_ordering() {
1481        // The electrum.blockstream.info:50001 node always sends back ordered responses which will make this always pass.
1482        // However, many servers do not, so we use one of those servers for this test.
1483        let client = get_test_client();
1484        let heights: Vec<u32> = vec![1, 4, 8, 12, 222, 6666, 12];
1485        let result_times = [
1486            1231469665, 1231470988, 1231472743, 1231474888, 1231770653, 1236456633, 1231474888,
1487        ];
1488        // Check ordering 10 times. This usually fails within 5 if ordering is incorrect.
1489        for _ in 0..10 {
1490            let results = client.batch_block_header(&heights).unwrap();
1491            for (index, result) in results.iter().enumerate() {
1492                assert_eq!(result_times[index], result.time);
1493            }
1494        }
1495    }
1496
1497    #[test]
1498    fn test_estimate_fee() {
1499        let client = get_test_client();
1500
1501        let resp = client.estimate_fee(10, None).unwrap();
1502        assert!(resp > 0.0);
1503    }
1504
1505    #[test]
1506    fn test_block_header() {
1507        let client = get_test_client();
1508
1509        let resp = client.block_header(0).unwrap();
1510        assert_eq!(resp.version, bitcoin::block::Version::ONE);
1511        assert_eq!(resp.time, 1231006505);
1512        assert_eq!(resp.nonce, 0x7c2bac1d);
1513    }
1514
1515    #[test]
1516    fn test_block_header_raw() {
1517        let client = get_test_client();
1518
1519        let resp = client.block_header_raw(0).unwrap();
1520        assert_eq!(
1521            resp,
1522            vec![
1523                1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1524                0, 0, 0, 0, 0, 0, 0, 0, 59, 163, 237, 253, 122, 123, 18, 178, 122, 199, 44, 62,
1525                103, 118, 143, 97, 127, 200, 27, 195, 136, 138, 81, 50, 58, 159, 184, 170, 75, 30,
1526                94, 74, 41, 171, 95, 73, 255, 255, 0, 29, 29, 172, 43, 124
1527            ]
1528        );
1529    }
1530
1531    #[test]
1532    fn test_block_headers() {
1533        let client = get_test_client();
1534
1535        let resp = client.block_headers(0, 4).unwrap();
1536        assert_eq!(resp.count, 4);
1537        assert_eq!(resp.max, 2016);
1538        assert_eq!(resp.headers.len(), 4);
1539
1540        assert_eq!(resp.headers[0].time, 1231006505);
1541    }
1542
1543    #[test]
1544    fn test_script_get_balance() {
1545        use std::str::FromStr;
1546
1547        let client = get_test_client();
1548
1549        // Realistically nobody will ever spend from this address, so we can expect the balance to
1550        // increase over time
1551        let addr = bitcoin::Address::from_str("1CounterpartyXXXXXXXXXXXXXXXUWLpVr")
1552            .unwrap()
1553            .assume_checked();
1554        let resp = client.script_get_balance(&addr.script_pubkey()).unwrap();
1555        assert!(resp.confirmed >= 213091301265);
1556    }
1557
1558    #[test]
1559    fn test_script_get_history() {
1560        use std::str::FromStr;
1561
1562        use bitcoin::Txid;
1563
1564        let client = get_test_client();
1565
1566        // Mt.Gox hack address
1567        let addr = bitcoin::Address::from_str("1FeexV6bAHb8ybZjqQMjJrcCrHGW9sb6uF")
1568            .unwrap()
1569            .assume_checked();
1570        let resp = client.script_get_history(&addr.script_pubkey()).unwrap();
1571
1572        assert!(resp.len() >= 328);
1573        assert_eq!(
1574            resp[0].tx_hash,
1575            Txid::from_str("e67a0550848b7932d7796aeea16ab0e48a5cfe81c4e8cca2c5b03e0416850114")
1576                .unwrap()
1577        );
1578    }
1579
1580    #[test]
1581    fn test_script_list_unspent() {
1582        use bitcoin::Txid;
1583        use std::str::FromStr;
1584
1585        let client = get_test_client();
1586
1587        // Peter todd's sha256 bounty address https://bitcointalk.org/index.php?topic=293382.0
1588        let addr = bitcoin::Address::from_str("35Snmmy3uhaer2gTboc81ayCip4m9DT4ko")
1589            .unwrap()
1590            .assume_checked();
1591        let resp = client.script_list_unspent(&addr.script_pubkey()).unwrap();
1592
1593        assert!(resp.len() >= 9);
1594        let txid = "397f12ee15f8a3d2ab25c0f6bb7d3c64d2038ca056af10dd8251b98ae0f076b0";
1595        let txid = Txid::from_str(txid).unwrap();
1596        let txs: Vec<_> = resp.iter().filter(|e| e.tx_hash == txid).collect();
1597        assert_eq!(txs.len(), 1);
1598        assert_eq!(txs[0].value, 10000000);
1599        assert_eq!(txs[0].height, 257674);
1600        assert_eq!(txs[0].tx_pos, 1);
1601    }
1602
1603    #[test]
1604    fn test_batch_script_list_unspent() {
1605        use std::str::FromStr;
1606
1607        let client = get_test_client();
1608
1609        // Peter todd's sha256 bounty address https://bitcointalk.org/index.php?topic=293382.0
1610        let script_1 = bitcoin::Address::from_str("35Snmmy3uhaer2gTboc81ayCip4m9DT4ko")
1611            .unwrap()
1612            .assume_checked()
1613            .script_pubkey();
1614
1615        let resp = client
1616            .batch_script_list_unspent(vec![script_1.as_script()])
1617            .unwrap();
1618        assert_eq!(resp.len(), 1);
1619        assert!(resp[0].len() >= 9);
1620    }
1621
1622    #[test]
1623    fn test_batch_estimate_fee() {
1624        let client = get_test_client();
1625
1626        let resp = client.batch_estimate_fee(vec![10, 20]).unwrap();
1627        assert_eq!(resp.len(), 2);
1628        assert!(resp[0] > 0.0);
1629        assert!(resp[1] > 0.0);
1630    }
1631
1632    #[test]
1633    fn test_transaction_get() {
1634        use bitcoin::{transaction, Txid};
1635
1636        let client = get_test_client();
1637
1638        let resp = client
1639            .transaction_get(
1640                &Txid::from_str("cc2ca076fd04c2aeed6d02151c447ced3d09be6fb4d4ef36cb5ed4e7a3260566")
1641                    .unwrap(),
1642            )
1643            .unwrap();
1644        assert_eq!(resp.version, transaction::Version::ONE);
1645        assert_eq!(resp.lock_time.to_consensus_u32(), 0);
1646    }
1647
1648    #[test]
1649    fn test_transaction_get_raw() {
1650        use bitcoin::Txid;
1651
1652        let client = get_test_client();
1653
1654        let resp = client
1655            .transaction_get_raw(
1656                &Txid::from_str("cc2ca076fd04c2aeed6d02151c447ced3d09be6fb4d4ef36cb5ed4e7a3260566")
1657                    .unwrap(),
1658            )
1659            .unwrap();
1660        assert_eq!(
1661            resp,
1662            vec![
1663                1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1664                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 84, 3, 240, 156, 9, 27, 77,
1665                105, 110, 101, 100, 32, 98, 121, 32, 65, 110, 116, 80, 111, 111, 108, 49, 49, 57,
1666                174, 0, 111, 32, 7, 77, 101, 40, 250, 190, 109, 109, 42, 177, 148, 141, 80, 179,
1667                217, 145, 226, 160, 130, 29, 247, 67, 88, 237, 156, 37, 83, 175, 0, 199, 166, 31,
1668                151, 119, 28, 160, 172, 238, 16, 110, 4, 0, 0, 0, 0, 0, 0, 0, 203, 236, 0, 128, 36,
1669                97, 249, 5, 255, 255, 255, 255, 3, 84, 206, 172, 42, 0, 0, 0, 0, 25, 118, 169, 20,
1670                17, 219, 228, 140, 198, 182, 23, 249, 198, 173, 175, 77, 158, 213, 246, 37, 177,
1671                199, 203, 89, 136, 172, 0, 0, 0, 0, 0, 0, 0, 0, 38, 106, 36, 170, 33, 169, 237, 46,
1672                87, 139, 206, 44, 166, 198, 188, 147, 89, 55, 115, 69, 216, 233, 133, 221, 95, 144,
1673                199, 132, 33, 255, 166, 239, 165, 235, 96, 66, 142, 105, 140, 0, 0, 0, 0, 0, 0, 0,
1674                0, 38, 106, 36, 185, 225, 27, 109, 47, 98, 29, 126, 195, 244, 90, 94, 202, 137,
1675                211, 234, 106, 41, 76, 223, 58, 4, 46, 151, 48, 9, 88, 68, 112, 161, 41, 22, 17,
1676                30, 44, 170, 1, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1677                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
1678            ]
1679        )
1680    }
1681
1682    #[test]
1683    fn test_transaction_get_merkle() {
1684        use bitcoin::Txid;
1685
1686        let client = get_test_client();
1687
1688        let txid =
1689            Txid::from_str("1f7ff3c407f33eabc8bec7d2cc230948f2249ec8e591bcf6f971ca9366c8788d")
1690                .unwrap();
1691        let resp = client.transaction_get_merkle(&txid, 630000).unwrap();
1692        assert_eq!(resp.block_height, 630000);
1693        assert_eq!(resp.pos, 68);
1694        assert_eq!(resp.merkle.len(), 12);
1695        assert_eq!(
1696            resp.merkle[0],
1697            [
1698                34, 65, 51, 64, 49, 139, 115, 189, 185, 246, 70, 225, 168, 193, 217, 195, 47, 66,
1699                179, 240, 153, 24, 114, 215, 144, 196, 212, 41, 39, 155, 246, 25
1700            ]
1701        );
1702
1703        // Check we can verify the merkle proof validity, but fail if we supply wrong data.
1704        let block_header = client.block_header(resp.block_height).unwrap();
1705        assert!(utils::validate_merkle_proof(
1706            &txid,
1707            &block_header.merkle_root,
1708            &resp
1709        ));
1710
1711        let mut fail_resp = resp.clone();
1712        fail_resp.pos = 13;
1713        assert!(!utils::validate_merkle_proof(
1714            &txid,
1715            &block_header.merkle_root,
1716            &fail_resp
1717        ));
1718
1719        let fail_block_header = client.block_header(resp.block_height + 1).unwrap();
1720        assert!(!utils::validate_merkle_proof(
1721            &txid,
1722            &fail_block_header.merkle_root,
1723            &resp
1724        ));
1725    }
1726
1727    #[test]
1728    fn test_batch_transaction_get_merkle() {
1729        use bitcoin::Txid;
1730
1731        struct TestCase {
1732            txid: Txid,
1733            block_height: usize,
1734            exp_pos: usize,
1735            exp_bytes: [u8; 32],
1736        }
1737
1738        let client = get_test_client();
1739
1740        let test_cases: Vec<TestCase> = vec![
1741            TestCase {
1742                txid: Txid::from_str(
1743                    "1f7ff3c407f33eabc8bec7d2cc230948f2249ec8e591bcf6f971ca9366c8788d",
1744                )
1745                .unwrap(),
1746                block_height: 630000,
1747                exp_pos: 68,
1748                exp_bytes: [
1749                    34, 65, 51, 64, 49, 139, 115, 189, 185, 246, 70, 225, 168, 193, 217, 195, 47,
1750                    66, 179, 240, 153, 24, 114, 215, 144, 196, 212, 41, 39, 155, 246, 25,
1751                ],
1752            },
1753            TestCase {
1754                txid: Txid::from_str(
1755                    "70a8639bc9b743c0610d1231103a2f8e99f4a25670946b91f16c55a5373b37d1",
1756                )
1757                .unwrap(),
1758                block_height: 630001,
1759                exp_pos: 25,
1760                exp_bytes: [
1761                    169, 100, 34, 99, 168, 101, 25, 168, 184, 90, 77, 50, 151, 245, 130, 101, 193,
1762                    229, 136, 128, 63, 110, 241, 19, 242, 59, 184, 137, 245, 249, 188, 110,
1763                ],
1764            },
1765            TestCase {
1766                txid: Txid::from_str(
1767                    "a0db149ace545beabbd87a8d6b20ffd6aa3b5a50e58add49a3d435f898c272cf",
1768                )
1769                .unwrap(),
1770                block_height: 840000,
1771                exp_pos: 0,
1772                exp_bytes: [
1773                    43, 184, 95, 75, 0, 75, 230, 218, 84, 247, 102, 193, 124, 30, 133, 81, 135, 50,
1774                    113, 18, 194, 49, 239, 47, 243, 94, 186, 208, 234, 103, 198, 158,
1775                ],
1776            },
1777        ];
1778
1779        let txids_and_heights: Vec<(Txid, usize)> = test_cases
1780            .iter()
1781            .map(|case| (case.txid, case.block_height))
1782            .collect();
1783
1784        let resp = client
1785            .batch_transaction_get_merkle(&txids_and_heights)
1786            .unwrap();
1787
1788        for (i, (res, test_case)) in resp.iter().zip(test_cases).enumerate() {
1789            assert_eq!(res.block_height, test_case.block_height);
1790            assert_eq!(res.pos, test_case.exp_pos);
1791            assert_eq!(res.merkle.len(), 12);
1792            assert_eq!(res.merkle[0], test_case.exp_bytes);
1793
1794            // Check we can verify the merkle proof validity, but fail if we supply wrong data.
1795            let block_header = client.block_header(res.block_height).unwrap();
1796            assert!(utils::validate_merkle_proof(
1797                &txids_and_heights[i].0,
1798                &block_header.merkle_root,
1799                res
1800            ));
1801
1802            let mut fail_res = res.clone();
1803            fail_res.pos = 13;
1804            assert!(!utils::validate_merkle_proof(
1805                &txids_and_heights[i].0,
1806                &block_header.merkle_root,
1807                &fail_res
1808            ));
1809
1810            let fail_block_header = client.block_header(res.block_height + 1).unwrap();
1811            assert!(!utils::validate_merkle_proof(
1812                &txids_and_heights[i].0,
1813                &fail_block_header.merkle_root,
1814                res
1815            ));
1816        }
1817    }
1818
1819    #[test]
1820    fn test_txid_from_pos() {
1821        use bitcoin::Txid;
1822
1823        let client = get_test_client();
1824
1825        let txid =
1826            Txid::from_str("1f7ff3c407f33eabc8bec7d2cc230948f2249ec8e591bcf6f971ca9366c8788d")
1827                .unwrap();
1828        let resp = client.txid_from_pos(630000, 68).unwrap();
1829        assert_eq!(resp, txid);
1830    }
1831
1832    #[test]
1833    fn test_txid_from_pos_with_merkle() {
1834        use bitcoin::Txid;
1835
1836        let client = get_test_client();
1837
1838        let txid =
1839            Txid::from_str("1f7ff3c407f33eabc8bec7d2cc230948f2249ec8e591bcf6f971ca9366c8788d")
1840                .unwrap();
1841        let resp = client.txid_from_pos_with_merkle(630000, 68).unwrap();
1842        assert_eq!(resp.tx_hash, txid);
1843        assert_eq!(
1844            resp.merkle[0],
1845            [
1846                34, 65, 51, 64, 49, 139, 115, 189, 185, 246, 70, 225, 168, 193, 217, 195, 47, 66,
1847                179, 240, 153, 24, 114, 215, 144, 196, 212, 41, 39, 155, 246, 25
1848            ]
1849        );
1850    }
1851
1852    #[test]
1853    fn test_ping() {
1854        let client = get_test_client();
1855        client.ping().unwrap();
1856    }
1857
1858    #[test]
1859    fn test_block_headers_subscribe() {
1860        let client = get_test_client();
1861        let resp = client.block_headers_subscribe().unwrap();
1862
1863        assert!(resp.height >= 639000);
1864    }
1865
1866    #[test]
1867    fn test_script_subscribe() {
1868        use std::str::FromStr;
1869
1870        let client = get_test_client();
1871
1872        // Mt.Gox hack address
1873        let addr = bitcoin::Address::from_str("1FeexV6bAHb8ybZjqQMjJrcCrHGW9sb6uF")
1874            .unwrap()
1875            .assume_checked();
1876
1877        // Just make sure that the call returns Ok(something)
1878        client.script_subscribe(&addr.script_pubkey()).unwrap();
1879    }
1880
1881    #[test]
1882    fn test_request_after_error() {
1883        let client = get_test_client();
1884
1885        assert!(client.transaction_broadcast_raw(&[0x00]).is_err());
1886        assert!(client.server_features().is_ok());
1887    }
1888
1889    #[test]
1890    fn test_raw_call() {
1891        use crate::types::Param;
1892
1893        let client = get_test_client();
1894
1895        let params = vec![
1896            Param::String(
1897                "cc2ca076fd04c2aeed6d02151c447ced3d09be6fb4d4ef36cb5ed4e7a3260566".to_string(),
1898            ),
1899            Param::Bool(false),
1900        ];
1901
1902        let resp = client
1903            .raw_call("blockchain.transaction.get", params)
1904            .unwrap();
1905
1906        assert_eq!(
1907            resp,
1908            "01000000000101000000000000000000000000000000000000000000000000000\
1909            0000000000000ffffffff5403f09c091b4d696e656420627920416e74506f6f6c3\
1910            13139ae006f20074d6528fabe6d6d2ab1948d50b3d991e2a0821df74358ed9c255\
1911            3af00c7a61f97771ca0acee106e0400000000000000cbec00802461f905fffffff\
1912            f0354ceac2a000000001976a91411dbe48cc6b617f9c6adaf4d9ed5f625b1c7cb5\
1913            988ac0000000000000000266a24aa21a9ed2e578bce2ca6c6bc9359377345d8e98\
1914            5dd5f90c78421ffa6efa5eb60428e698c0000000000000000266a24b9e11b6d2f6\
1915            21d7ec3f45a5eca89d3ea6a294cdf3a042e973009584470a12916111e2caa01200\
1916            000000000000000000000000000000000000000000000000000000000000000000\
1917            00000"
1918        )
1919    }
1920
1921    #[test]
1922    fn test_authorization_provider_with_client() {
1923        use std::sync::{Arc, RwLock};
1924
1925        // Track how many times the provider is called
1926        let call_count = Arc::new(RwLock::new(0));
1927        let call_count_clone = call_count.clone();
1928
1929        let auth_provider = Arc::new(move || {
1930            *call_count_clone.write().unwrap() += 1;
1931            Some("Bearer test-token-123".to_string())
1932        });
1933
1934        let client = get_test_auth_client(Some(auth_provider));
1935
1936        // Make a request - provider should be called
1937        let _ = client.server_features();
1938
1939        // Provider should have been called at least once
1940        assert!(*call_count.read().unwrap() >= 1);
1941    }
1942
1943    #[test]
1944    fn test_authorization_provider_dynamic_token_refresh() {
1945        use std::sync::{Arc, RwLock};
1946
1947        // Simulate a token that can be refreshed
1948        let token = Arc::new(RwLock::new("initial-token".to_string()));
1949        let token_clone = token.clone();
1950
1951        let auth_provider: AuthProvider =
1952            Arc::new(move || Some(token_clone.read().unwrap().clone()));
1953
1954        let client = get_test_auth_client(Some(auth_provider.clone()));
1955
1956        // Make first request with initial token
1957        let _ = client.server_features();
1958
1959        // Simulate token refresh
1960        *token.write().unwrap() = "refreshed-token".to_string();
1961
1962        // Make second request - should use the new token
1963        let _ = client.server_features();
1964
1965        // Verify the provider now returns the refreshed token
1966        assert_eq!(auth_provider(), Some("refreshed-token".to_string()));
1967    }
1968
1969    #[test]
1970    fn test_authorization_provider_returns_none() {
1971        use std::sync::Arc;
1972
1973        let auth_provider: AuthProvider = Arc::new(|| None);
1974
1975        let client = get_test_auth_client(Some(auth_provider));
1976
1977        // Should still work when provider returns None
1978        let result = client.server_features();
1979        assert!(result.is_ok());
1980    }
1981
1982    #[test]
1983    fn test_auth_provider_via_constructor() {
1984        use std::sync::Arc;
1985
1986        let auth_provider: AuthProvider = Arc::new(|| Some("Bearer test".to_string()));
1987
1988        let client = get_test_auth_client(Some(auth_provider));
1989
1990        // Verify the provider was set
1991        let result = client.server_features();
1992        assert!(result.is_ok());
1993    }
1994}