pravega_client/segment/
raw_client.rs

1//
2// Copyright (c) Dell Inc., or its subsidiaries. All Rights Reserved.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10
11use pravega_client_shared::PravegaNodeUri;
12use pravega_connection_pool::connection_pool::{ConnectionPool, ConnectionPoolError};
13use pravega_wire_protocol::client_connection::{ClientConnection, ClientConnectionImpl};
14use pravega_wire_protocol::commands::{Reply, Request};
15use pravega_wire_protocol::connection_factory::SegmentConnectionManager;
16use pravega_wire_protocol::error::ClientConnectionError;
17use pravega_wire_protocol::wire_commands::{Replies, Requests};
18
19use async_trait::async_trait;
20use snafu::ResultExt;
21use snafu::Snafu;
22use std::fmt;
23use std::fmt::Debug;
24use tokio::time::error::Elapsed;
25use tokio::time::{timeout, Duration};
26
27#[derive(Debug, Snafu)]
28pub enum RawClientError {
29    #[snafu(display("Auth token has expired, refresh to try again: {}", reply))]
30    AuthTokenExpired { reply: Replies },
31
32    #[snafu(display("Failed to get connection from connection pool: {}", source))]
33    GetConnectionFromPool { source: ConnectionPoolError },
34
35    #[snafu(display("Failed to write request: {}", source))]
36    WriteRequest { source: ClientConnectionError },
37
38    #[snafu(display("Failed to read reply: {}", source))]
39    ReadReply { source: ClientConnectionError },
40
41    #[snafu(display("Reply incompatible wirecommand version: low {}, high {}", low, high))]
42    IncompatibleVersion { low: i32, high: i32 },
43
44    #[snafu(display("Request has timed out: {:?}", source))]
45    RequestTimeout { source: Elapsed },
46
47    #[snafu(display("Wrong reply id {:?} for request {:?}", reply_id, request_id))]
48    WrongReplyId { reply_id: i64, request_id: i64 },
49}
50
51impl RawClientError {
52    pub fn is_token_expired(&self) -> bool {
53        matches!(self, RawClientError::AuthTokenExpired { .. })
54    }
55}
56
57// RawClient is on top of the ClientConnection. It provides methods that take
58// Request enums and return Reply enums asynchronously. It has logic to process some of the replies from
59// server and return the processed result to caller.
60#[async_trait]
61pub(crate) trait RawClient<'a>: Send + Sync {
62    // Asynchronously send a request to the server and receive a response.
63    async fn send_request_with_connection(
64        &self,
65        request: &Requests,
66        client_connection: &mut ClientConnection,
67    ) -> Result<Replies, RawClientError>
68    where
69        'a: 'async_trait;
70
71    // Asynchronously send a request to the server and receive a response.
72    async fn send_request(&self, request: &Requests) -> Result<Replies, RawClientError>
73    where
74        'a: 'async_trait;
75
76    // Asynchronously send a request to the server and receive a response and return the connection to the caller.
77    async fn send_setup_request(
78        &self,
79        request: &Requests,
80    ) -> Result<(Replies, Box<dyn ClientConnection + 'a>), RawClientError>
81    where
82        'a: 'async_trait;
83}
84
85pub(crate) struct RawClientImpl<'a> {
86    pool: &'a ConnectionPool<SegmentConnectionManager>,
87    endpoint: PravegaNodeUri,
88    timeout: Duration,
89}
90
91impl<'a> fmt::Debug for RawClientImpl<'a> {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        write!(f, "RawClient endpoint: {:?}", self.endpoint)
94    }
95}
96
97impl<'a> RawClientImpl<'a> {
98    pub(crate) fn new(
99        pool: &'a ConnectionPool<SegmentConnectionManager>,
100        endpoint: PravegaNodeUri,
101        timeout: Duration,
102    ) -> RawClientImpl<'a> {
103        RawClientImpl {
104            pool,
105            endpoint,
106            timeout,
107        }
108    }
109}
110
111#[allow(clippy::needless_lifetimes)]
112#[async_trait]
113impl<'a> RawClient<'a> for RawClientImpl<'a> {
114    async fn send_request_with_connection(
115        &self,
116        request: &Requests,
117        client_connection: &mut ClientConnection,
118    ) -> Result<Replies, RawClientError> {
119        client_connection.write(request).await.context(WriteRequest {})?;
120        let read_future = client_connection.read();
121        let result = timeout(self.timeout, read_future)
122            .await
123            .context(RequestTimeout {})?;
124        let reply = result.context(ReadReply {})?;
125        if reply.get_request_id() != request.get_request_id() {
126            client_connection.set_failure();
127            return Err(RawClientError::WrongReplyId {
128                reply_id: reply.get_request_id(),
129                request_id: request.get_request_id(),
130            });
131        }
132        check_auth_token_expired(&reply)?;
133        Ok(reply)
134    }
135
136    async fn send_request(&self, request: &Requests) -> Result<Replies, RawClientError> {
137        let connection = self
138            .pool
139            .get_connection(self.endpoint.clone())
140            .await
141            .context(GetConnectionFromPool {})?;
142        let mut client_connection = ClientConnectionImpl::new(connection);
143        client_connection.write(request).await.context(WriteRequest {})?;
144        let read_future = client_connection.read();
145        let result = timeout(self.timeout, read_future)
146            .await
147            .context(RequestTimeout {})?;
148        let reply = result.context(ReadReply {})?;
149        if reply.get_request_id() != request.get_request_id() {
150            client_connection.set_failure();
151            return Err(RawClientError::WrongReplyId {
152                reply_id: reply.get_request_id(),
153                request_id: request.get_request_id(),
154            });
155        }
156        check_auth_token_expired(&reply)?;
157        Ok(reply)
158    }
159
160    async fn send_setup_request(
161        &self,
162        request: &Requests,
163    ) -> Result<(Replies, Box<dyn ClientConnection + 'a>), RawClientError> {
164        let connection = self
165            .pool
166            .get_connection(self.endpoint.clone())
167            .await
168            .context(GetConnectionFromPool {})?;
169        let mut client_connection = ClientConnectionImpl::new(connection);
170        client_connection.write(request).await.context(WriteRequest {})?;
171        let read_future = client_connection.read();
172        let result = timeout(self.timeout, read_future)
173            .await
174            .context(RequestTimeout {})?;
175        let reply = result.context(ReadReply {})?;
176        if reply.get_request_id() != request.get_request_id() {
177            client_connection.set_failure();
178            return Err(RawClientError::WrongReplyId {
179                reply_id: reply.get_request_id(),
180                request_id: request.get_request_id(),
181            });
182        }
183        check_auth_token_expired(&reply)?;
184        Ok((reply, Box::new(client_connection) as Box<dyn ClientConnection>))
185    }
186}
187
188fn check_auth_token_expired(reply: &Replies) -> Result<(), RawClientError> {
189    if let Replies::AuthTokenCheckFailed(ref cmd) = reply {
190        if cmd.is_token_expired() {
191            return Err(RawClientError::AuthTokenExpired { reply: reply.clone() });
192        }
193    }
194    Ok(())
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use pravega_client_config::connection_type::ConnectionType;
    use pravega_wire_protocol::commands::{HelloCommand, ReadSegmentCommand, SegmentReadCommand};
    use pravega_wire_protocol::connection_factory::{ConnectionFactory, ConnectionFactoryConfig};
    use pravega_wire_protocol::wire_commands::Encode;
    use std::io::{Read, Write};
    use std::net::{SocketAddr, TcpListener};
    use std::thread;
    use tokio::runtime::Runtime;

    /// Test fixture: a tokio runtime and a segment connection pool built from a
    /// `ConnectionType::Tokio` connection factory.
    struct Common {
        rt: Runtime,
        pool: ConnectionPool<SegmentConnectionManager>,
    }

    impl Common {
        fn new() -> Self {
            let rt = Runtime::new().expect("create tokio Runtime");
            let config = ConnectionFactoryConfig::new(ConnectionType::Tokio);
            let connection_factory = ConnectionFactory::create(config);
            let manager = SegmentConnectionManager::new(connection_factory, 1);
            let pool = ConnectionPool::new(manager);
            Common { rt, pool }
        }
    }

    /// Blocking TCP listener on an ephemeral localhost port that plays the server side.
    struct Server {
        address: SocketAddr,
        listener: TcpListener,
    }

    impl Server {
        pub fn new() -> Server {
            let listener = TcpListener::bind("127.0.0.1:0").expect("local server");
            let address = listener.local_addr().unwrap();
            Server { address, listener }
        }

        /// Accepts one connection and writes a `Hello` reply (high 9, low 5) to it.
        pub fn send_hello(&mut self) {
            self.accept_and_send_hello(HelloCommand {
                high_version: 9,
                low_version: 5,
            });
        }

        /// Accepts one connection and writes a `Hello` reply (high 10, low 10) to it.
        pub fn send_hello_wrong_version(&mut self) {
            self.accept_and_send_hello(HelloCommand {
                high_version: 10,
                low_version: 10,
            });
        }

        /// Serializes `hello` as a reply, accepts exactly one incoming connection and writes
        /// the reply to it. `accept()` replaces a `for` loop over `incoming()` that broke
        /// after its first iteration.
        fn accept_and_send_hello(&self, hello: HelloCommand) {
            let reply = Replies::Hello(hello)
                .write_fields()
                .expect("serialize hello wirecommand");

            let (mut stream, _) = self.listener.accept().expect("get tcp stream");
            stream.write_all(&reply).expect("reply with hello wirecommand");
        }
    }

    #[test]
    #[should_panic] // since connection verify will panic
    fn test_hello() {
        let common = Common::new();
        let mut server = Server::new();

        let raw_client = RawClientImpl::new(
            &common.pool,
            PravegaNodeUri::from(server.address),
            Duration::from_secs(3600),
        );
        let h = thread::spawn(move || {
            server.send_hello();
        });
        let request = Requests::Hello(HelloCommand {
            low_version: 5,
            high_version: 9,
        });

        let reply = common
            .rt
            .block_on(raw_client.send_request(&request))
            .expect("get reply");

        assert_eq!(
            reply,
            Replies::Hello(HelloCommand {
                high_version: 9,
                low_version: 5,
            })
        );
        h.join().expect("thread finished");
    }

    #[test]
    fn test_invalid_connection() {
        let common = Common::new();
        let server = Server::new();

        let raw_client = RawClientImpl::new(
            &common.pool,
            PravegaNodeUri::from(server.address),
            Duration::from_secs(30),
        );

        // Server script: each accepted connection is read in 8-byte chunks. After the first
        // chunk the server writes a Hello reply; after the second it writes one more frame and
        // stops reading. The first connection gets a frame whose (presumed) length field is
        // 0xFFFFFFFF, so the client's read fails. The second gets a valid SegmentRead, and
        // then the server stops accepting.
        let h = thread::spawn(move || {
            let mut conn = 0;
            for stream in server.listener.incoming() {
                conn += 1;
                if let Ok(mut stream) = stream {
                    if conn == 1 {
                        let mut cnt = 0;
                        let mut buf = [0u8; 8];
                        while stream.read_exact(&mut buf).is_ok() {
                            cnt += 1;
                            if cnt == 1 {
                                let reply = Replies::Hello(HelloCommand {
                                    high_version: 15,
                                    low_version: 5,
                                });
                                let bytes = reply.write_fields().expect("serialize reply");
                                stream.write_all(&bytes).expect("send hello");
                            } else if cnt == 2 {
                                let invalid_frame: [u8; 16] =
                                    [0, 0, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 255, 255, 255, 255];
                                stream
                                    .write_all(&invalid_frame)
                                    .expect("send invalid payload");
                                break;
                            }
                        }
                    } else {
                        let mut cnt = 0;
                        let mut buf = [0u8; 8];
                        while stream.read_exact(&mut buf).is_ok() {
                            cnt += 1;
                            if cnt == 1 {
                                let reply = Replies::Hello(HelloCommand {
                                    high_version: 15,
                                    low_version: 5,
                                });
                                let bytes = reply.write_fields().expect("serialize reply");
                                stream.write_all(&bytes).expect("send valid payload");
                            } else if cnt == 2 {
                                let reply = Replies::SegmentRead(SegmentReadCommand {
                                    segment: "foo".to_string(),
                                    offset: 0,
                                    at_tail: false,
                                    end_of_segment: false,
                                    data: vec![0, 0, 0, 0],
                                    request_id: 0,
                                });
                                let bytes = reply.write_fields().expect("serialize reply");
                                stream.write_all(&bytes).expect("send valid payload");
                                break;
                            }
                        }
                    }
                }
                if conn == 2 {
                    break;
                }
            }
        });
        let request = Requests::ReadSegment(ReadSegmentCommand {
            segment: "foo".to_string(),
            offset: 0,
            suggested_length: 0,
            delegation_token: "".to_string(),
            request_id: 0,
        });

        let res = common.rt.block_on(raw_client.send_request(&request));
        // payload length too long
        assert!(res.is_err());

        // retry
        let reply = common
            .rt
            .block_on(raw_client.send_request(&request))
            .expect("get reply");

        assert_eq!(
            reply,
            Replies::SegmentRead(SegmentReadCommand {
                segment: "foo".to_string(),
                offset: 0,
                at_tail: false,
                end_of_segment: false,
                data: vec![0, 0, 0, 0],
                request_id: 0
            })
        );
        h.join().expect("thread finished");
    }
}