// stratum_server/server.rs

1use crate::{
2    global::Global,
3    id_manager::IDManager,
4    route::Endpoint,
5    router::Router,
6    tcp::Handler,
7    types::{ConnectionID, GlobalVars, ReadyIndicator},
8    BanManager, ConfigManager, Connection, Result, SessionList, StratumServerBuilder,
9};
10use extended_primitives::Buffer;
11use futures::StreamExt;
12use rlimit::Resource;
13use std::{
14    net::SocketAddr,
15    sync::Arc,
16    time::{Duration, Instant},
17};
18use tokio::task::JoinSet;
19use tokio_stream::wrappers::TcpListenerStream;
20use tokio_util::sync::CancellationToken;
21use tracing::{error, info, trace, warn};
22
23pub struct StratumServer<State, CState>
24where
25    State: Clone,
26    CState: Default + Clone,
27{
28    pub(crate) id: u8,
29    pub(crate) listen_address: SocketAddr,
30    pub(crate) listener: TcpListenerStream,
31    pub(crate) state: State,
32    pub(crate) session_list: SessionList<CState>,
33    pub(crate) ban_manager: BanManager,
34    pub(crate) config_manager: ConfigManager,
35    pub(crate) router: Arc<Router<State, CState>>,
36    pub(crate) session_id_manager: IDManager,
37    pub(crate) cancel_token: CancellationToken,
38    pub(crate) global_thread_list: JoinSet<()>,
39    pub(crate) ready_indicator: ReadyIndicator,
40    pub(crate) shutdown_message: Option<Buffer>,
41    #[cfg(feature = "api")]
42    pub(crate) api: crate::api::Api,
43}
44
45impl<State, CState> StratumServer<State, CState>
46where
47    State: Clone + Send + Sync + 'static,
48    CState: Default + Clone + Send + Sync + 'static,
49{
50    pub fn builder(state: State, server_id: u8) -> StratumServerBuilder<State, CState> {
51        StratumServerBuilder::new(state, server_id)
52    }
53
54    pub fn add(&mut self, method: &str, ep: impl Endpoint<State, CState>) {
55        let router = Arc::get_mut(&mut self.router)
56            .expect("Registering routes is not possible after the Server has started");
57        router.add(method, ep);
58    }
59
60    pub fn global(&mut self, global_name: &str, ep: impl Global<State, CState>) {
61        self.global_thread_list.spawn({
62            let state = self.state.clone();
63            let session_list = self.session_list.clone();
64            let cancel_token = self.get_cancel_token();
65            let global_name = global_name.to_string();
66            async move {
67                tokio::select! {
68                    res = ep.call(state, session_list) => {
69                        if let Err(e) = res {
70                            error!(cause = ?e, "Global thread {} failed.", global_name);
71                        }
72                    }
73                    () = cancel_token.cancelled() => {
74                        info!("Global thread {} is shutting down from shutdown message.", global_name);
75                    }
76
77                }
78            }
79        });
80    }
81
82    async fn handle_incoming(&mut self) -> Result<()> {
83        info!("Listening on {}", &self.listen_address);
84
85        while let Some(stream) = self.listener.next().await {
86            let stream = match stream {
87                Ok(stream) => stream,
88                Err(e) => {
89                    error!(cause = ?e, "Unable to access stream");
90                    continue;
91                }
92            };
93
94            let id = ConnectionID::new();
95            let child_token = self.get_cancel_token();
96
97            trace!(
98                id = ?id,
99                ip = &stream.peer_addr()?.to_string(),
100                "Connection initialized",
101            );
102
103            let connection = match Connection::new(id.clone(), stream, child_token.clone()) {
104                Ok(connection) => connection,
105                Err(e) => {
106                    error!(id = ?id, cause = ?e, "Failed while constructing Connection");
107                    continue;
108                }
109            };
110
111            let handler = Handler {
112                id: id.clone(),
113                ban_manager: self.ban_manager.clone(),
114                id_manager: self.session_id_manager.clone(),
115                session_list: self.session_list.clone(),
116                router: self.router.clone(),
117                state: self.state.clone(),
118                connection_state: CState::default(),
119                config_manager: self.config_manager.clone(),
120                cancel_token: child_token,
121                global_vars: GlobalVars::new(self.id),
122                connection,
123            };
124
125            tokio::spawn(async move {
126                if let Err(err) = handler.run().await {
127                    error!(id =?id, cause = ?err, "connection error");
128                }
129            });
130        }
131
132        Ok(())
133    }
134
135    pub async fn start(&mut self) -> Result<()> {
136        init()?;
137
138        let cancel_token = self.cancel_token.clone();
139
140        #[cfg(feature = "api")]
141        let api_handle = self.api.run(cancel_token.clone())?;
142
143        tokio::select! {
144            res = self.handle_incoming() => {
145                if let Err(err) = res {
146                    error!(cause = %err, "failed to accept");
147                };
148            },
149            () = cancel_token.cancelled() => {}
150        }
151
152        let start = Instant::now();
153
154        //Session Shutdowns
155        {
156            self.session_list
157                .shutdown_msg(self.shutdown_message.clone())?;
158
159            let mut backoff = 1;
160            loop {
161                let connected_miners = self.session_list.len();
162                if connected_miners == 0 {
163                    break;
164                }
165
166                if backoff > 64 {
167                    warn!("{connected_miners} remaining, force shutting down now");
168                    self.session_list.shutdown();
169                    break;
170                }
171
172                info!("Waiting for all miners to disconnect, {connected_miners} remaining");
173                tokio::time::sleep(Duration::from_secs(backoff)).await;
174
175                backoff *= 2;
176            }
177        }
178
179        info!("Awaiting for all current globals to complete");
180        while let Some(res) = self.global_thread_list.join_next().await {
181            if let Err(err) = res {
182                error!(cause = %err, "Global thread failed to shut down gracefully.");
183            }
184        }
185
186        #[cfg(feature = "api")]
187        {
188            info!("Waiting for Api handler to finish");
189            if let Err(err) = api_handle.await {
190                error!(cause = %err, "API failed to shut down gracefully.");
191            }
192        }
193
194        info!("Shutdown complete in {} ns", start.elapsed().as_nanos());
195
196        Ok(())
197    }
198
199    pub fn get_ready_indicator(&self) -> ReadyIndicator {
200        self.ready_indicator.create_new()
201    }
202
203    // #[cfg(test)]
204    pub fn get_miner_list(&self) -> SessionList<CState> {
205        self.session_list.clone()
206    }
207
208    pub fn get_cancel_token(&self) -> CancellationToken {
209        self.cancel_token.child_token()
210    }
211
212    pub fn get_address(&self) -> SocketAddr {
213        self.listen_address
214    }
215
216    pub fn get_ban_manager(&self) -> BanManager {
217        self.ban_manager.clone()
218    }
219
220    #[cfg(feature = "api")]
221    pub fn get_api_address(&self) -> SocketAddr {
222        self.api.listen_address()
223    }
224}
225
226fn init() -> Result<()> {
227    info!("Initializing...");
228
229    //Check that the system will support what we need.
230    let (hard, soft) = rlimit::getrlimit(Resource::NOFILE)?;
231
232    info!("Current Ulimit is set to {hard} hard limit, {soft} soft limit");
233
234    info!("Initialization Complete");
235
236    Ok(())
237}
238
239//@todo
240// #[cfg(test)]
241// mod tests {
242//
243//     #[derive(Clone)]
244//     pub struct State {}
245//
246//     pub async fn handle_auth(
247//         req: StratumRequest<State>,
248//         _connection: Arc<Session<ConnectionState>>,
249//     ) -> Result<bool> {
250//         let state = req.state();
251//
252//         let login = state.auth.login().await;
253//
254//         Ok(login)
255//     }
256//
257//     #[tokio::test]
258//     async fn test_server_add() {
259//         let builder = StratumServer::builder(state, 1)
260//             .with_host("0.0.0.0")
261//             .with_port(0);
262//
263//         let mut server = builder.build().await?;
264//
265//         let address = server.get_address();
266//
267//         server.add("auth", handle_auth);
268//     }
269// }