bb8_async_ssh2_lite/
impl_tokio.rs

1use core::cmp::max;
2use std::net::SocketAddr;
3use std::path::Path;
4
5use async_ssh2_lite::{AsyncSession, AsyncSftp, SessionConfiguration, TokioTcpStream};
6use async_trait::async_trait;
7use tokio_crate::sync::Semaphore;
8
9use crate::{AsyncSessionManagerError, AsyncSessionUserauthType, AsyncSftpManagerError};
10
11//
12//
13//
14#[derive(Debug)]
15#[non_exhaustive]
16pub struct AsyncSessionManagerWithTokioTcpStream {
17    pub socket_addr: SocketAddr,
18    pub configuration: Option<SessionConfiguration>,
19    pub username: String,
20    pub userauth_type: AsyncSessionUserauthType,
21    //
22    max_number_of_unauthenticated_conns: Option<Semaphore>,
23}
24
25impl Clone for AsyncSessionManagerWithTokioTcpStream {
26    fn clone(&self) -> Self {
27        Self {
28            socket_addr: self.socket_addr,
29            configuration: self.configuration.clone(),
30            username: self.username.clone(),
31            userauth_type: self.userauth_type.clone(),
32            //
33            max_number_of_unauthenticated_conns: self
34                .max_number_of_unauthenticated_conns
35                .as_ref()
36                .map(|max_number_of_unauthenticated_conns| {
37                    Semaphore::new(max_number_of_unauthenticated_conns.available_permits())
38                }),
39        }
40    }
41}
42
43impl AsyncSessionManagerWithTokioTcpStream {
44    pub fn new(
45        socket_addr: SocketAddr,
46        configuration: impl Into<Option<SessionConfiguration>>,
47        username: impl AsRef<str>,
48        userauth_type: AsyncSessionUserauthType,
49    ) -> Self {
50        Self {
51            socket_addr,
52            configuration: configuration.into(),
53            username: username.as_ref().into(),
54            userauth_type,
55            //
56            max_number_of_unauthenticated_conns: None,
57        }
58    }
59
60    pub fn set_max_number_of_unauthenticated_conns(
61        &mut self,
62        max_number_of_unauthenticated_conns: usize,
63    ) {
64        self.max_number_of_unauthenticated_conns =
65            Some(Semaphore::new(max(1, max_number_of_unauthenticated_conns)));
66    }
67
68    pub fn get_max_number_of_unauthenticated_conns(&self) -> Option<usize> {
69        self.max_number_of_unauthenticated_conns
70            .as_ref()
71            .map(|x| x.available_permits())
72    }
73}
74
75#[async_trait]
76impl bb8::ManageConnection for AsyncSessionManagerWithTokioTcpStream {
77    type Connection = AsyncSession<TokioTcpStream>;
78
79    type Error = AsyncSessionManagerError;
80
81    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
82        let semaphore_permit = if let Some(x) = self.max_number_of_unauthenticated_conns.as_ref() {
83            Some(
84                x.acquire()
85                    .await
86                    .map_err(|err| AsyncSessionManagerError::Unknown(err.to_string()))?,
87            )
88        } else {
89            None
90        };
91
92        //
93        match connect_inner(
94            self.socket_addr,
95            self.configuration.to_owned(),
96            &self.username,
97            &self.userauth_type,
98        )
99        .await
100        {
101            Ok(session) => {
102                drop(semaphore_permit);
103
104                Ok(session)
105            }
106            Err(err) => {
107                drop(semaphore_permit);
108
109                Err(err)
110            }
111        }
112    }
113
114    async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
115        Ok(())
116    }
117
118    fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
119        false
120    }
121}
122
123//
124//
125//
126#[derive(Debug, Clone)]
127pub struct AsyncSftpManagerWithTokioTcpStream(pub AsyncSessionManagerWithTokioTcpStream);
128
129#[async_trait]
130impl bb8::ManageConnection for AsyncSftpManagerWithTokioTcpStream {
131    type Connection = AsyncSftp<TokioTcpStream>;
132
133    type Error = AsyncSftpManagerError;
134
135    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
136        let session = self
137            .0
138            .connect()
139            .await
140            .map_err(AsyncSftpManagerError::AsyncSessionManagerError)?;
141
142        let sftp = session
143            .sftp()
144            .await
145            .map_err(AsyncSftpManagerError::OpenError)?;
146
147        Ok(sftp)
148    }
149
150    async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
151        conn.stat(Path::new("/")).await.map_err(|e| {
152            AsyncSftpManagerError::AsyncSessionManagerError(AsyncSessionManagerError::ConnectError(
153                e,
154            ))
155        })?;
156        Ok(())
157    }
158
159    fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
160        false
161    }
162}
163
/*
There is intentionally no AsyncChannelManagerWithTokioTcpStream, because reusing a
channel fails with:
Ssh2(Error { code: Session(-39), msg: "Channel can not be reused" })
*/
168
169//
170//
171//
172async fn connect_inner(
173    socket_addr: SocketAddr,
174    configuration: Option<SessionConfiguration>,
175    username: &str,
176    userauth_type: &AsyncSessionUserauthType,
177) -> Result<AsyncSession<TokioTcpStream>, AsyncSessionManagerError> {
178    let mut session = AsyncSession::<TokioTcpStream>::connect(socket_addr, configuration)
179        .await
180        .map_err(AsyncSessionManagerError::ConnectError)?;
181
182    session
183        .handshake()
184        .await
185        .map_err(AsyncSessionManagerError::HandshakeError)?;
186
187    match userauth_type {
188        AsyncSessionUserauthType::Password { password } => {
189            session
190                .userauth_password(username, password)
191                .await
192                .map_err(AsyncSessionManagerError::UserauthError)?;
193        }
194        AsyncSessionUserauthType::Agent => {
195            session
196                .userauth_agent(username)
197                .await
198                .map_err(AsyncSessionManagerError::UserauthError)?;
199        }
200        AsyncSessionUserauthType::PubkeyFile {
201            pubkey,
202            privatekey,
203            passphrase,
204        } => {
205            session
206                .userauth_pubkey_file(
207                    username,
208                    pubkey.as_deref(),
209                    privatekey,
210                    passphrase.as_deref(),
211                )
212                .await
213                .map_err(AsyncSessionManagerError::UserauthError)?;
214        }
215    }
216
217    if !session.authenticated() {
218        return Err(AsyncSessionManagerError::AssertAuthenticated);
219    }
220
221    Ok(session)
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    use std::{env, sync::Arc};

    use bb8::ManageConnection as _;
    use futures_util::future::join_all;
    use tokio_crate as tokio;

    /// Fires three batches of eight concurrent `connect` calls (each capped at five
    /// seconds) against a manager limited to four concurrent attempts, then checks
    /// that every permit has been handed back once all attempts are done.
    ///
    /// The test is skipped (returns `Ok`) when the target host cannot be resolved.
    #[tokio::test]
    async fn test_max_number_of_unauthenticated_conns() -> Result<(), Box<dyn std::error::Error>> {
        let host = env::var("SSH_SERVER_HOST_AND_PORT").unwrap_or_else(|_| "google.com:443".into());

        let first_resolved = tokio::net::lookup_host(host)
            .await
            .map(|mut addrs| addrs.next());
        let addr = match first_resolved {
            Ok(Some(addr)) => addr,
            Ok(None) => {
                eprintln!("lookup_host result empty");
                return Ok(());
            }
            Err(err) => {
                eprintln!("lookup_host failed, err:{err}");
                return Ok(());
            }
        };

        let max_number_of_unauthenticated_conns = 4;

        let user = env::var("USER").unwrap_or_else(|_| "root".into());
        let mut mgr = AsyncSessionManagerWithTokioTcpStream::new(
            addr,
            None,
            user,
            AsyncSessionUserauthType::Agent,
        );
        mgr.set_max_number_of_unauthenticated_conns(max_number_of_unauthenticated_conns);

        let mgr = Arc::new(mgr);

        // Background reporter: prints the free permit count every 100 ms for the
        // lifetime of the test runtime.
        let reporter_mgr = Arc::clone(&mgr);
        tokio::spawn(async move {
            loop {
                println!(
                    "max_number_of_unauthenticated_conns:{:?}",
                    reporter_mgr.get_max_number_of_unauthenticated_conns()
                );
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            }
        });

        let now = std::time::Instant::now();

        let mut handles = vec![];
        for _batch in 0..3 {
            handles.extend((0..8).map(|_| {
                let mgr = Arc::clone(&mgr);
                tokio::spawn(async move {
                    tokio::time::timeout(tokio::time::Duration::from_secs(5), mgr.connect()).await
                })
            }));
            // Pause between batches; this is also what the elapsed-time assertion relies on.
            tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
        }
        join_all(handles).await;

        // With every attempt finished, all permits must be free again.
        let free_permits = mgr.get_max_number_of_unauthenticated_conns();
        assert_eq!(free_permits, Some(max_number_of_unauthenticated_conns));

        let elapsed_dur = now.elapsed();
        println!("elapsed_dur:{elapsed_dur:?}",);
        assert!(elapsed_dur.as_millis() >= 300 * 3);

        Ok(())
    }
}
303}