1use crate::client_connection::{read_wirecommand, write_wirecommand};
12use crate::commands::{HelloCommand, TableKey, TableValue, OLDEST_COMPATIBLE_VERSION, WIRE_VERSION};
13use crate::connection::{Connection, TlsConnection, TokioConnection};
14use crate::error::*;
15use crate::mock_connection::MockConnection;
16use crate::wire_commands::{Replies, Requests};
17use async_trait::async_trait;
18use pravega_client_config::connection_type::MockType;
19use pravega_client_config::{connection_type::ConnectionType, ClientConfig};
20use pravega_client_shared::{NoVerifier, PravegaNodeUri, SegmentInfo};
21use pravega_connection_pool::connection_pool::{ConnectionPoolError, Manager};
22use snafu::ResultExt;
23use std::collections::HashMap;
24use std::fmt;
25use std::io::BufReader;
26use std::sync::Arc;
27use tokio::net::TcpStream;
28use tokio::sync::Mutex;
29use tokio_rustls::{rustls, webpki::DNSNameRef, TlsConnector};
30use tracing::{debug, info};
31use uuid::Uuid;
32
33#[async_trait]
35pub trait ConnectionFactory: Send + Sync {
36 async fn establish_connection(
56 &self,
57 endpoint: PravegaNodeUri,
58 ) -> Result<Box<dyn Connection>, ConnectionFactoryError>;
59}
60
61impl dyn ConnectionFactory {
62 pub fn create(config: ConnectionFactoryConfig) -> Box<dyn ConnectionFactory> {
63 match config.connection_type {
64 ConnectionType::Tokio => Box::new(TokioConnectionFactory::new(
65 config.is_tls_enabled,
66 config.certs,
67 config.disable_cert_verification,
68 )),
69 ConnectionType::Mock(mock_type) => Box::new(MockConnectionFactory::new(mock_type)),
70 }
71 }
72}
73
74struct TokioConnectionFactory {
75 tls_enabled: bool,
76 certs: Vec<String>,
77 disable_cert_verification: bool,
78}
79
80impl TokioConnectionFactory {
81 fn new(tls_enabled: bool, certs: Vec<String>, disable_cert_verification: bool) -> Self {
82 TokioConnectionFactory {
83 tls_enabled,
84 certs,
85 disable_cert_verification,
86 }
87 }
88}
89
90#[async_trait]
91impl ConnectionFactory for TokioConnectionFactory {
92 async fn establish_connection(
93 &self,
94 endpoint: PravegaNodeUri,
95 ) -> Result<Box<dyn Connection>, ConnectionFactoryError> {
96 let connection_type = ConnectionType::Tokio;
97 let uuid = Uuid::new_v4();
98 let mut tokio_connection = if self.tls_enabled {
99 info!(
100 "establish connection to segmentstore {:?} using TLS channel",
101 endpoint
102 );
103 let mut config = rustls::ClientConfig::new();
104 if self.disable_cert_verification {
105 config.dangerous().set_certificate_verifier(Arc::new(NoVerifier));
106 }
107 for cert in &self.certs {
108 let mut pem = BufReader::new(cert.as_bytes());
109
110 let res = config.root_store.add_pem_file(&mut pem);
111 match res {
112 Ok((valid, invalid)) => {
113 debug!(
114 "pem file contains {} valid certs and {} invalid certs",
115 valid, invalid
116 );
117 }
118 Err(_e) => {
119 debug!("failed to add cert files {}", cert);
120 }
121 }
122 }
123 let connector = TlsConnector::from(Arc::new(config));
124 let stream = TcpStream::connect(endpoint.to_socket_addr())
125 .await
126 .context(Connect {
127 connection_type,
128 endpoint: endpoint.clone(),
129 })?;
130 let domain_name = endpoint.domain_name();
133 let domain = DNSNameRef::try_from_ascii_str(&domain_name).expect("get domain name");
134 let stream = connector
135 .connect(domain, stream)
136 .await
137 .expect("connect to tls stream");
138 Box::new(TlsConnection {
141 uuid,
142 endpoint: endpoint.clone(),
143 stream: Some(stream),
144 can_recycle: false,
145 }) as Box<dyn Connection>
146 } else {
147 let stream = TcpStream::connect(endpoint.to_socket_addr())
148 .await
149 .context(Connect {
150 connection_type,
151 endpoint: endpoint.clone(),
152 })?;
153 Box::new(TokioConnection {
156 uuid,
157 endpoint: endpoint.clone(),
158 stream: Some(stream),
159 can_recycle: false,
160 }) as Box<dyn Connection>
161 };
162 verify_connection(&mut *tokio_connection)
163 .await
164 .context(Verify {})?;
165 Ok(tokio_connection)
166 }
167}
168
169type TableSegmentIndex = HashMap<String, HashMap<TableKey, TableValue>>;
170type TableSegment = HashMap<String, Vec<(TableKey, TableValue)>>;
171
172struct MockConnectionFactory {
173 segments: Arc<Mutex<HashMap<String, SegmentInfo>>>,
174 writers: Arc<Mutex<HashMap<u128, String>>>,
175 table_segment_index: Arc<Mutex<TableSegmentIndex>>,
176 table_segment: Arc<Mutex<TableSegment>>,
177 mock_type: MockType,
178}
179
180impl MockConnectionFactory {
181 pub fn new(mock_type: MockType) -> Self {
182 MockConnectionFactory {
183 segments: Arc::new(Mutex::new(HashMap::new())),
184 writers: Arc::new(Mutex::new(HashMap::new())),
185 table_segment_index: Arc::new(Mutex::new(HashMap::new())),
186 table_segment: Arc::new(Mutex::new(HashMap::new())),
187 mock_type,
188 }
189 }
190}
191
192#[async_trait]
193impl ConnectionFactory for MockConnectionFactory {
194 async fn establish_connection(
195 &self,
196 endpoint: PravegaNodeUri,
197 ) -> Result<Box<dyn Connection>, ConnectionFactoryError> {
198 let mock = MockConnection::new(
199 endpoint,
200 self.segments.clone(),
201 self.writers.clone(),
202 self.table_segment_index.clone(),
203 self.table_segment.clone(),
204 self.mock_type,
205 );
206 Ok(Box::new(mock) as Box<dyn Connection>)
207 }
208}
209
210async fn verify_connection(conn: &mut dyn Connection) -> Result<(), ClientConnectionError> {
211 let request = Requests::Hello(HelloCommand {
212 high_version: WIRE_VERSION,
213 low_version: OLDEST_COMPATIBLE_VERSION,
214 });
215 write_wirecommand(conn, &request).await?;
216 let reply = read_wirecommand(conn).await?;
217 match reply {
218 Replies::Hello(cmd) => {
219 if cmd.low_version <= WIRE_VERSION && cmd.high_version >= WIRE_VERSION {
220 Ok(())
221 } else {
222 Err(ClientConnectionError::WrongHelloVersion {
223 wire_version: WIRE_VERSION,
224 oldest_compatible: OLDEST_COMPATIBLE_VERSION,
225 wire_version_received: cmd.high_version,
226 oldest_compatible_received: cmd.low_version,
227 })
228 }
229 }
230 _ => Err(ClientConnectionError::WrongReply { reply }),
231 }
232}
233
234pub struct SegmentConnectionManager {
237 connection_factory: Box<dyn ConnectionFactory>,
240
241 max_connections_in_pool: u32,
243}
244
245impl SegmentConnectionManager {
246 pub fn new(connection_factory: Box<dyn ConnectionFactory>, max_connections_in_pool: u32) -> Self {
247 SegmentConnectionManager {
248 connection_factory,
249 max_connections_in_pool,
250 }
251 }
252}
253
254#[async_trait]
255impl Manager for SegmentConnectionManager {
256 type Conn = Box<dyn Connection>;
257
258 async fn establish_connection(
259 &self,
260 endpoint: PravegaNodeUri,
261 ) -> Result<Self::Conn, ConnectionPoolError> {
262 let result = self
263 .connection_factory
264 .establish_connection(endpoint.clone())
265 .await;
266
267 match result {
268 Ok(conn) => Ok(conn),
269 Err(e) => Err(ConnectionPoolError::EstablishConnection {
270 endpoint: endpoint.to_string(),
271 error_msg: format!("Could not establish connection due to {:?}", e),
272 }),
273 }
274 }
275
276 fn is_valid(&self, conn: &Self::Conn) -> bool {
277 conn.is_valid()
278 }
279
280 fn get_max_connections(&self) -> u32 {
281 self.max_connections_in_pool
282 }
283
284 fn name(&self) -> String {
285 "SegmentConnectionManager".to_owned()
286 }
287}
288
289impl fmt::Debug for SegmentConnectionManager {
290 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291 f.debug_struct("SegmentConnectionManager")
292 .field("max connections in pool", &self.max_connections_in_pool)
293 .finish()
294 }
295}
296
297#[derive(new)]
299pub struct ConnectionFactoryConfig {
300 connection_type: ConnectionType,
301 #[new(value = "false")]
302 is_tls_enabled: bool,
303 #[new(default)]
304 certs: Vec<String>,
305 #[new(value = "false")]
306 disable_cert_verification: bool,
307}
308
309impl From<&ClientConfig> for ConnectionFactoryConfig {
311 fn from(client_config: &ClientConfig) -> Self {
312 ConnectionFactoryConfig {
313 connection_type: client_config.connection_type,
314 is_tls_enabled: client_config.is_tls_enabled,
315 certs: client_config.trustcerts.clone(),
316 disable_cert_verification: client_config.disable_cert_verification,
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wire_commands::{Decode, Encode};
    use log::info;
    use pravega_client_config::connection_type::{ConnectionType, MockType};
    use tokio::runtime::Runtime;

    /// A `Happy` mock connection echoes a written `Hello` back unchanged.
    #[test]
    fn test_mock_connection() {
        info!("test mock connection factory");
        let rt = Runtime::new().unwrap();
        let config = ConnectionFactoryConfig::new(ConnectionType::Mock(MockType::Happy));
        let connection_factory = ConnectionFactory::create(config);
        let connection_future =
            connection_factory.establish_connection(PravegaNodeUri::from("127.1.1.1:9090"));
        let mut mock_connection = rt.block_on(connection_future).unwrap();

        let request = Requests::Hello(HelloCommand {
            high_version: 9,
            low_version: 5,
        })
        .write_fields()
        .unwrap();
        let len = request.len();
        rt.block_on(mock_connection.send_async(&request))
            .expect("write to mock connection");
        let mut buf = vec![0; len];
        rt.block_on(mock_connection.read_async(&mut buf))
            .expect("read from mock connection");
        let reply = Replies::read_from(&buf).unwrap();
        let expected = Replies::Hello(HelloCommand {
            high_version: 9,
            low_version: 5,
        });
        assert_eq!(reply, expected);
        info!("mock connection factory test passed");
    }

    /// Connecting to an address with no listener must surface an error, not panic.
    ///
    /// Previously this relied on a bare `#[should_panic]`, which would also pass on any
    /// unrelated panic (for example a runtime construction failure).
    #[test]
    fn test_tokio_connection() {
        info!("test tokio connection factory");
        let rt = Runtime::new().unwrap();
        let config = ConnectionFactoryConfig::new(ConnectionType::Tokio);
        let connection_factory = ConnectionFactory::create(config);
        // Assumes nothing is listening on this loopback address/port in the test environment.
        let connection_future =
            connection_factory.establish_connection(PravegaNodeUri::from("127.1.1.1:9090".to_string()));
        let result = rt.block_on(connection_future);
        assert!(
            result.is_err(),
            "establishing a connection to an unused endpoint should fail"
        );

        info!("tokio connection factory test passed");
    }
}