pravega_wire_protocol/
connection_factory.rs

1//
2// Copyright (c) Dell Inc., or its subsidiaries. All Rights Reserved.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10
11use crate::client_connection::{read_wirecommand, write_wirecommand};
12use crate::commands::{HelloCommand, TableKey, TableValue, OLDEST_COMPATIBLE_VERSION, WIRE_VERSION};
13use crate::connection::{Connection, TlsConnection, TokioConnection};
14use crate::error::*;
15use crate::mock_connection::MockConnection;
16use crate::wire_commands::{Replies, Requests};
17use async_trait::async_trait;
18use pravega_client_config::connection_type::MockType;
19use pravega_client_config::{connection_type::ConnectionType, ClientConfig};
20use pravega_client_shared::{NoVerifier, PravegaNodeUri, SegmentInfo};
21use pravega_connection_pool::connection_pool::{ConnectionPoolError, Manager};
22use snafu::ResultExt;
23use std::collections::HashMap;
24use std::fmt;
25use std::io::BufReader;
26use std::sync::Arc;
27use tokio::net::TcpStream;
28use tokio::sync::Mutex;
29use tokio_rustls::{rustls, webpki::DNSNameRef, TlsConnector};
30use tracing::{debug, info};
31use uuid::Uuid;
32
33/// ConnectionFactory trait is the factory used to establish the TCP connection with remote servers.
34#[async_trait]
35pub trait ConnectionFactory: Send + Sync {
36    /// establish_connection will return a Connection future that used to send and read data.
37    ///
38    /// # Example
39    ///
40    /// ```no_run
41    /// use pravega_wire_protocol::connection_factory::{ConnectionFactory, ConnectionFactoryConfig};
42    /// use pravega_client_shared::PravegaNodeUri;
43    /// use pravega_client_config::connection_type::ConnectionType;
44    /// use tokio::runtime::Runtime;
45    ///
46    /// fn main() {
47    ///   let mut rt = Runtime::new().unwrap();
48    ///   let endpoint = PravegaNodeUri::from("localhost:9090".to_string());
49    ///   let config = ConnectionFactoryConfig::new(ConnectionType::Tokio);
50    ///   let cf = ConnectionFactory::create(config);
51    ///   let connection_future = cf.establish_connection(endpoint);
52    ///   let mut connection = rt.block_on(connection_future).unwrap();
53    /// }
54    /// ```
55    async fn establish_connection(
56        &self,
57        endpoint: PravegaNodeUri,
58    ) -> Result<Box<dyn Connection>, ConnectionFactoryError>;
59}
60
61impl dyn ConnectionFactory {
62    pub fn create(config: ConnectionFactoryConfig) -> Box<dyn ConnectionFactory> {
63        match config.connection_type {
64            ConnectionType::Tokio => Box::new(TokioConnectionFactory::new(
65                config.is_tls_enabled,
66                config.certs,
67                config.disable_cert_verification,
68            )),
69            ConnectionType::Mock(mock_type) => Box::new(MockConnectionFactory::new(mock_type)),
70        }
71    }
72}
73
/// Connection factory that opens TCP connections, optionally wrapped in TLS.
struct TokioConnectionFactory {
    /// Whether new connections are wrapped in TLS.
    tls_enabled: bool,
    /// Certificates (PEM text) added to the TLS root store when TLS is enabled.
    certs: Vec<String>,
    /// When set, TLS connections replace the certificate verifier with `NoVerifier`.
    disable_cert_verification: bool,
}

impl TokioConnectionFactory {
    /// Stores the TLS settings used for every connection this factory opens.
    fn new(tls_enabled: bool, certs: Vec<String>, disable_cert_verification: bool) -> Self {
        Self {
            certs,
            disable_cert_verification,
            tls_enabled,
        }
    }
}
89
90#[async_trait]
91impl ConnectionFactory for TokioConnectionFactory {
92    async fn establish_connection(
93        &self,
94        endpoint: PravegaNodeUri,
95    ) -> Result<Box<dyn Connection>, ConnectionFactoryError> {
96        let connection_type = ConnectionType::Tokio;
97        let uuid = Uuid::new_v4();
98        let mut tokio_connection = if self.tls_enabled {
99            info!(
100                "establish connection to segmentstore {:?} using TLS channel",
101                endpoint
102            );
103            let mut config = rustls::ClientConfig::new();
104            if self.disable_cert_verification {
105                config.dangerous().set_certificate_verifier(Arc::new(NoVerifier));
106            }
107            for cert in &self.certs {
108                let mut pem = BufReader::new(cert.as_bytes());
109
110                let res = config.root_store.add_pem_file(&mut pem);
111                match res {
112                    Ok((valid, invalid)) => {
113                        debug!(
114                            "pem file contains {} valid certs and {} invalid certs",
115                            valid, invalid
116                        );
117                    }
118                    Err(_e) => {
119                        debug!("failed to add cert files {}", cert);
120                    }
121                }
122            }
123            let connector = TlsConnector::from(Arc::new(config));
124            let stream = TcpStream::connect(endpoint.to_socket_addr())
125                .await
126                .context(Connect {
127                    connection_type,
128                    endpoint: endpoint.clone(),
129                })?;
130            // Endpoint returned by controller by default is an IP address, it is necessary to configure
131            // Pravega to return a hostname. Check pravegaservice.service.published.host.nameOrIp property.
132            let domain_name = endpoint.domain_name();
133            let domain = DNSNameRef::try_from_ascii_str(&domain_name).expect("get domain name");
134            let stream = connector
135                .connect(domain, stream)
136                .await
137                .expect("connect to tls stream");
138            // set connection as invalid initially.
139            // applications should decide whether this connection is safe to be recycled.
140            Box::new(TlsConnection {
141                uuid,
142                endpoint: endpoint.clone(),
143                stream: Some(stream),
144                can_recycle: false,
145            }) as Box<dyn Connection>
146        } else {
147            let stream = TcpStream::connect(endpoint.to_socket_addr())
148                .await
149                .context(Connect {
150                    connection_type,
151                    endpoint: endpoint.clone(),
152                })?;
153            // set connection as invalid initially.
154            // applications should decide whether this connection is safe to be recycled.
155            Box::new(TokioConnection {
156                uuid,
157                endpoint: endpoint.clone(),
158                stream: Some(stream),
159                can_recycle: false,
160            }) as Box<dyn Connection>
161        };
162        verify_connection(&mut *tokio_connection)
163            .await
164            .context(Verify {})?;
165        Ok(tokio_connection)
166    }
167}
168
169type TableSegmentIndex = HashMap<String, HashMap<TableKey, TableValue>>;
170type TableSegment = HashMap<String, Vec<(TableKey, TableValue)>>;
171
172struct MockConnectionFactory {
173    segments: Arc<Mutex<HashMap<String, SegmentInfo>>>,
174    writers: Arc<Mutex<HashMap<u128, String>>>,
175    table_segment_index: Arc<Mutex<TableSegmentIndex>>,
176    table_segment: Arc<Mutex<TableSegment>>,
177    mock_type: MockType,
178}
179
180impl MockConnectionFactory {
181    pub fn new(mock_type: MockType) -> Self {
182        MockConnectionFactory {
183            segments: Arc::new(Mutex::new(HashMap::new())),
184            writers: Arc::new(Mutex::new(HashMap::new())),
185            table_segment_index: Arc::new(Mutex::new(HashMap::new())),
186            table_segment: Arc::new(Mutex::new(HashMap::new())),
187            mock_type,
188        }
189    }
190}
191
192#[async_trait]
193impl ConnectionFactory for MockConnectionFactory {
194    async fn establish_connection(
195        &self,
196        endpoint: PravegaNodeUri,
197    ) -> Result<Box<dyn Connection>, ConnectionFactoryError> {
198        let mock = MockConnection::new(
199            endpoint,
200            self.segments.clone(),
201            self.writers.clone(),
202            self.table_segment_index.clone(),
203            self.table_segment.clone(),
204            self.mock_type,
205        );
206        Ok(Box::new(mock) as Box<dyn Connection>)
207    }
208}
209
210async fn verify_connection(conn: &mut dyn Connection) -> Result<(), ClientConnectionError> {
211    let request = Requests::Hello(HelloCommand {
212        high_version: WIRE_VERSION,
213        low_version: OLDEST_COMPATIBLE_VERSION,
214    });
215    write_wirecommand(conn, &request).await?;
216    let reply = read_wirecommand(conn).await?;
217    match reply {
218        Replies::Hello(cmd) => {
219            if cmd.low_version <= WIRE_VERSION && cmd.high_version >= WIRE_VERSION {
220                Ok(())
221            } else {
222                Err(ClientConnectionError::WrongHelloVersion {
223                    wire_version: WIRE_VERSION,
224                    oldest_compatible: OLDEST_COMPATIBLE_VERSION,
225                    wire_version_received: cmd.high_version,
226                    oldest_compatible_received: cmd.low_version,
227                })
228            }
229        }
230        _ => Err(ClientConnectionError::WrongReply { reply }),
231    }
232}
233
234/// An implementation of the Manager trait to integrate with ConnectionPool.
235/// This is for creating connections between Rust client and Segmentstore server.
236pub struct SegmentConnectionManager {
237    /// connection_factory is used to establish connection to the remote server
238    /// when there is no connection available in the internal pool.
239    connection_factory: Box<dyn ConnectionFactory>,
240
241    /// The client configuration.
242    max_connections_in_pool: u32,
243}
244
245impl SegmentConnectionManager {
246    pub fn new(connection_factory: Box<dyn ConnectionFactory>, max_connections_in_pool: u32) -> Self {
247        SegmentConnectionManager {
248            connection_factory,
249            max_connections_in_pool,
250        }
251    }
252}
253
254#[async_trait]
255impl Manager for SegmentConnectionManager {
256    type Conn = Box<dyn Connection>;
257
258    async fn establish_connection(
259        &self,
260        endpoint: PravegaNodeUri,
261    ) -> Result<Self::Conn, ConnectionPoolError> {
262        let result = self
263            .connection_factory
264            .establish_connection(endpoint.clone())
265            .await;
266
267        match result {
268            Ok(conn) => Ok(conn),
269            Err(e) => Err(ConnectionPoolError::EstablishConnection {
270                endpoint: endpoint.to_string(),
271                error_msg: format!("Could not establish connection due to {:?}", e),
272            }),
273        }
274    }
275
276    fn is_valid(&self, conn: &Self::Conn) -> bool {
277        conn.is_valid()
278    }
279
280    fn get_max_connections(&self) -> u32 {
281        self.max_connections_in_pool
282    }
283
284    fn name(&self) -> String {
285        "SegmentConnectionManager".to_owned()
286    }
287}
288
289impl fmt::Debug for SegmentConnectionManager {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        f.debug_struct("SegmentConnectionManager")
292            .field("max connections in pool", &self.max_connections_in_pool)
293            .finish()
294    }
295}
296
297/// The configuration for ConnectionFactory.
298#[derive(new)]
299pub struct ConnectionFactoryConfig {
300    connection_type: ConnectionType,
301    #[new(value = "false")]
302    is_tls_enabled: bool,
303    #[new(default)]
304    certs: Vec<String>,
305    #[new(value = "false")]
306    disable_cert_verification: bool,
307}
308
309/// ConnectionFactoryConfig can be built from ClientConfig.
310impl From<&ClientConfig> for ConnectionFactoryConfig {
311    fn from(client_config: &ClientConfig) -> Self {
312        ConnectionFactoryConfig {
313            connection_type: client_config.connection_type,
314            is_tls_enabled: client_config.is_tls_enabled,
315            certs: client_config.trustcerts.clone(),
316            disable_cert_verification: client_config.disable_cert_verification,
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wire_commands::{Decode, Encode};
    use log::info;
    use pravega_client_config::connection_type::{ConnectionType, MockType};
    use tokio::runtime::Runtime;

    /// A `Happy` mock connection answers a `Hello` request with a `Hello` reply carrying the
    /// same versions.
    #[test]
    fn test_mock_connection() {
        info!("test mock connection factory");
        let rt = Runtime::new().unwrap();
        let config = ConnectionFactoryConfig::new(ConnectionType::Mock(MockType::Happy));
        let connection_factory = <dyn ConnectionFactory>::create(config);
        let connection_future =
            connection_factory.establish_connection(PravegaNodeUri::from("127.1.1.1:9090"));
        let mut mock_connection = rt.block_on(connection_future).unwrap();

        let request = Requests::Hello(HelloCommand {
            high_version: 9,
            low_version: 5,
        })
        .write_fields()
        .unwrap();
        let len = request.len();
        rt.block_on(mock_connection.send_async(&request))
            .expect("write to mock connection");
        let mut buf = vec![0; len];
        rt.block_on(mock_connection.read_async(&mut buf))
            .expect("read from mock connection");
        let reply = Replies::read_from(&buf).unwrap();
        let expected = Replies::Hello(HelloCommand {
            high_version: 9,
            low_version: 5,
        });
        assert_eq!(reply, expected);
        info!("mock connection factory test passed");
    }

    /// Assumes nothing listens on 127.1.1.1:9090, so establishing the connection must fail.
    /// The expected message pins the panic to the `expect` below rather than any other panic.
    #[test]
    #[should_panic(expected = "create tokio connection")]
    fn test_tokio_connection() {
        info!("test tokio connection factory");
        let rt = Runtime::new().unwrap();
        let config = ConnectionFactoryConfig::new(ConnectionType::Tokio);
        let connection_factory = <dyn ConnectionFactory>::create(config);
        let connection_future =
            connection_factory.establish_connection(PravegaNodeUri::from("127.1.1.1:9090".to_string()));
        let _connection = rt.block_on(connection_future).expect("create tokio connection");

        info!("tokio connection factory test passed");
    }
}