bb8_async_ssh2_lite/
impl_tokio.rs1use core::cmp::max;
2use std::net::SocketAddr;
3use std::path::Path;
4
5use async_ssh2_lite::{AsyncSession, AsyncSftp, SessionConfiguration, TokioTcpStream};
6use async_trait::async_trait;
7use tokio_crate::sync::Semaphore;
8
9use crate::{AsyncSessionManagerError, AsyncSessionUserauthType, AsyncSftpManagerError};
10
11#[derive(Debug)]
15#[non_exhaustive]
16pub struct AsyncSessionManagerWithTokioTcpStream {
17 pub socket_addr: SocketAddr,
18 pub configuration: Option<SessionConfiguration>,
19 pub username: String,
20 pub userauth_type: AsyncSessionUserauthType,
21 max_number_of_unauthenticated_conns: Option<Semaphore>,
23}
24
25impl Clone for AsyncSessionManagerWithTokioTcpStream {
26 fn clone(&self) -> Self {
27 Self {
28 socket_addr: self.socket_addr,
29 configuration: self.configuration.clone(),
30 username: self.username.clone(),
31 userauth_type: self.userauth_type.clone(),
32 max_number_of_unauthenticated_conns: self
34 .max_number_of_unauthenticated_conns
35 .as_ref()
36 .map(|max_number_of_unauthenticated_conns| {
37 Semaphore::new(max_number_of_unauthenticated_conns.available_permits())
38 }),
39 }
40 }
41}
42
43impl AsyncSessionManagerWithTokioTcpStream {
44 pub fn new(
45 socket_addr: SocketAddr,
46 configuration: impl Into<Option<SessionConfiguration>>,
47 username: impl AsRef<str>,
48 userauth_type: AsyncSessionUserauthType,
49 ) -> Self {
50 Self {
51 socket_addr,
52 configuration: configuration.into(),
53 username: username.as_ref().into(),
54 userauth_type,
55 max_number_of_unauthenticated_conns: None,
57 }
58 }
59
60 pub fn set_max_number_of_unauthenticated_conns(
61 &mut self,
62 max_number_of_unauthenticated_conns: usize,
63 ) {
64 self.max_number_of_unauthenticated_conns =
65 Some(Semaphore::new(max(1, max_number_of_unauthenticated_conns)));
66 }
67
68 pub fn get_max_number_of_unauthenticated_conns(&self) -> Option<usize> {
69 self.max_number_of_unauthenticated_conns
70 .as_ref()
71 .map(|x| x.available_permits())
72 }
73}
74
75#[async_trait]
76impl bb8::ManageConnection for AsyncSessionManagerWithTokioTcpStream {
77 type Connection = AsyncSession<TokioTcpStream>;
78
79 type Error = AsyncSessionManagerError;
80
81 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
82 let semaphore_permit = if let Some(x) = self.max_number_of_unauthenticated_conns.as_ref() {
83 Some(
84 x.acquire()
85 .await
86 .map_err(|err| AsyncSessionManagerError::Unknown(err.to_string()))?,
87 )
88 } else {
89 None
90 };
91
92 match connect_inner(
94 self.socket_addr,
95 self.configuration.to_owned(),
96 &self.username,
97 &self.userauth_type,
98 )
99 .await
100 {
101 Ok(session) => {
102 drop(semaphore_permit);
103
104 Ok(session)
105 }
106 Err(err) => {
107 drop(semaphore_permit);
108
109 Err(err)
110 }
111 }
112 }
113
114 async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
115 Ok(())
116 }
117
118 fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
119 false
120 }
121}
122
123#[derive(Debug, Clone)]
127pub struct AsyncSftpManagerWithTokioTcpStream(pub AsyncSessionManagerWithTokioTcpStream);
128
129#[async_trait]
130impl bb8::ManageConnection for AsyncSftpManagerWithTokioTcpStream {
131 type Connection = AsyncSftp<TokioTcpStream>;
132
133 type Error = AsyncSftpManagerError;
134
135 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
136 let session = self
137 .0
138 .connect()
139 .await
140 .map_err(AsyncSftpManagerError::AsyncSessionManagerError)?;
141
142 let sftp = session
143 .sftp()
144 .await
145 .map_err(AsyncSftpManagerError::OpenError)?;
146
147 Ok(sftp)
148 }
149
150 async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
151 conn.stat(Path::new("/")).await.map_err(|e| {
152 AsyncSftpManagerError::AsyncSessionManagerError(AsyncSessionManagerError::ConnectError(
153 e,
154 ))
155 })?;
156 Ok(())
157 }
158
159 fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
160 false
161 }
162}
163
164async fn connect_inner(
173 socket_addr: SocketAddr,
174 configuration: Option<SessionConfiguration>,
175 username: &str,
176 userauth_type: &AsyncSessionUserauthType,
177) -> Result<AsyncSession<TokioTcpStream>, AsyncSessionManagerError> {
178 let mut session = AsyncSession::<TokioTcpStream>::connect(socket_addr, configuration)
179 .await
180 .map_err(AsyncSessionManagerError::ConnectError)?;
181
182 session
183 .handshake()
184 .await
185 .map_err(AsyncSessionManagerError::HandshakeError)?;
186
187 match userauth_type {
188 AsyncSessionUserauthType::Password { password } => {
189 session
190 .userauth_password(username, password)
191 .await
192 .map_err(AsyncSessionManagerError::UserauthError)?;
193 }
194 AsyncSessionUserauthType::Agent => {
195 session
196 .userauth_agent(username)
197 .await
198 .map_err(AsyncSessionManagerError::UserauthError)?;
199 }
200 AsyncSessionUserauthType::PubkeyFile {
201 pubkey,
202 privatekey,
203 passphrase,
204 } => {
205 session
206 .userauth_pubkey_file(
207 username,
208 pubkey.as_deref(),
209 privatekey,
210 passphrase.as_deref(),
211 )
212 .await
213 .map_err(AsyncSessionManagerError::UserauthError)?;
214 }
215 }
216
217 if !session.authenticated() {
218 return Err(AsyncSessionManagerError::AssertAuthenticated);
219 }
220
221 Ok(session)
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    use std::{env, sync::Arc};

    use bb8::ManageConnection as _;
    use futures_util::future::join_all;
    use tokio_crate as tokio;

    /// Sends three waves of eight concurrent `connect` calls, 300 ms apart, to the server
    /// named by `SSH_SERVER_HOST_AND_PORT` (default `google.com:443`).
    ///
    /// Afterwards it checks that the limiter again reports all of its permits as available.
    /// The test returns `Ok` early when the host cannot be resolved.
    #[tokio::test]
    async fn test_max_number_of_unauthenticated_conns() -> Result<(), Box<dyn std::error::Error>> {
        let host =
            env::var("SSH_SERVER_HOST_AND_PORT").unwrap_or_else(|_| "google.com:443".into());

        let resolved = tokio::net::lookup_host(host)
            .await
            .map(|mut addrs| addrs.next());
        let addr = match resolved {
            Ok(Some(addr)) => addr,
            Ok(None) => {
                eprintln!("lookup_host result empty");
                return Ok(());
            }
            Err(err) => {
                eprintln!("lookup_host failed, err:{err}");
                return Ok(());
            }
        };

        let limit: usize = 4;

        let mgr = Arc::new({
            let mut mgr = AsyncSessionManagerWithTokioTcpStream::new(
                addr,
                None,
                env::var("USER").unwrap_or_else(|_| "root".into()),
                AsyncSessionUserauthType::Agent,
            );
            mgr.set_max_number_of_unauthenticated_conns(limit);
            mgr
        });

        // Periodically prints the limiter's available permits while the test runs.
        let reporter = Arc::clone(&mgr);
        tokio::spawn(async move {
            loop {
                println!(
                    "max_number_of_unauthenticated_conns:{:?}",
                    reporter.get_max_number_of_unauthenticated_conns()
                );
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            }
        });

        let started_at = std::time::Instant::now();

        let mut handles = Vec::new();
        for _wave in 0..3 {
            handles.extend((0..8).map(|_| {
                let mgr = Arc::clone(&mgr);
                tokio::spawn(async move {
                    tokio::time::timeout(tokio::time::Duration::from_secs(5), mgr.connect()).await
                })
            }));
            tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
        }
        join_all(handles).await;

        assert_eq!(mgr.get_max_number_of_unauthenticated_conns(), Some(limit));

        let elapsed_dur = started_at.elapsed();
        println!("elapsed_dur:{elapsed_dur:?}");
        assert!(elapsed_dur.as_millis() >= 300 * 3);

        Ok(())
    }
}