1use pravega_client_shared::PravegaNodeUri;
12use pravega_connection_pool::connection_pool::{ConnectionPool, ConnectionPoolError};
13use pravega_wire_protocol::client_connection::{ClientConnection, ClientConnectionImpl};
14use pravega_wire_protocol::commands::{Reply, Request};
15use pravega_wire_protocol::connection_factory::SegmentConnectionManager;
16use pravega_wire_protocol::error::ClientConnectionError;
17use pravega_wire_protocol::wire_commands::{Replies, Requests};
18
19use async_trait::async_trait;
20use snafu::ResultExt;
21use snafu::Snafu;
22use std::fmt;
23use std::fmt::Debug;
24use tokio::time::error::Elapsed;
25use tokio::time::{timeout, Duration};
26
/// Errors produced while sending a request through a [`RawClient`] and waiting for its reply.
#[derive(Debug, Snafu)]
pub enum RawClientError {
    /// The server answered with `AuthTokenCheckFailed` for an expired token; `reply` is that
    /// full reply. Callers can detect this case with [`RawClientError::is_token_expired`].
    #[snafu(display("Auth token has expired, refresh to try again: {}", reply))]
    AuthTokenExpired { reply: Replies },

    /// No connection to the endpoint could be obtained from the connection pool.
    #[snafu(display("Failed to get connection from connection pool: {}", source))]
    GetConnectionFromPool { source: ConnectionPoolError },

    /// Writing the request to the connection failed.
    #[snafu(display("Failed to write request: {}", source))]
    WriteRequest { source: ClientConnectionError },

    /// Reading the reply from the connection failed.
    #[snafu(display("Failed to read reply: {}", source))]
    ReadReply { source: ClientConnectionError },

    /// NOTE(review): not constructed in this file; presumably reported when a peer advertises
    /// an unsupported wire-command version range (`low`..`high`) — confirm against callers.
    #[snafu(display("Reply incompatible wirecommand version: low {}, high {}", low, high))]
    IncompatibleVersion { low: i32, high: i32 },

    /// No reply arrived within the client's configured timeout.
    #[snafu(display("Request has timed out: {:?}", source))]
    RequestTimeout { source: Elapsed },

    /// A reply was received, but its request id does not match the id of the request sent.
    #[snafu(display("Wrong reply id {:?} for request {:?}", reply_id, request_id))]
    WrongReplyId { reply_id: i64, request_id: i64 },
}
50
51impl RawClientError {
52 pub fn is_token_expired(&self) -> bool {
53 matches!(self, RawClientError::AuthTokenExpired { .. })
54 }
55}
56
57#[async_trait]
61pub(crate) trait RawClient<'a>: Send + Sync {
62 async fn send_request_with_connection(
64 &self,
65 request: &Requests,
66 client_connection: &mut ClientConnection,
67 ) -> Result<Replies, RawClientError>
68 where
69 'a: 'async_trait;
70
71 async fn send_request(&self, request: &Requests) -> Result<Replies, RawClientError>
73 where
74 'a: 'async_trait;
75
76 async fn send_setup_request(
78 &self,
79 request: &Requests,
80 ) -> Result<(Replies, Box<dyn ClientConnection + 'a>), RawClientError>
81 where
82 'a: 'async_trait;
83}
84
/// [`RawClient`] implementation that takes connections to `endpoint` from a shared pool.
pub(crate) struct RawClientImpl<'a> {
    /// Pool that connections are requested from; borrowed for `'a`.
    pool: &'a ConnectionPool<SegmentConnectionManager>,
    /// Endpoint passed to the pool whenever this client needs a connection.
    endpoint: PravegaNodeUri,
    /// Maximum time to wait for a reply after a request has been written.
    timeout: Duration,
}
90
91impl<'a> fmt::Debug for RawClientImpl<'a> {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 write!(f, "RawClient endpoint: {:?}", self.endpoint)
94 }
95}
96
97impl<'a> RawClientImpl<'a> {
98 pub(crate) fn new(
99 pool: &'a ConnectionPool<SegmentConnectionManager>,
100 endpoint: PravegaNodeUri,
101 timeout: Duration,
102 ) -> RawClientImpl<'a> {
103 RawClientImpl {
104 pool,
105 endpoint,
106 timeout,
107 }
108 }
109}
110
111#[allow(clippy::needless_lifetimes)]
112#[async_trait]
113impl<'a> RawClient<'a> for RawClientImpl<'a> {
114 async fn send_request_with_connection(
115 &self,
116 request: &Requests,
117 client_connection: &mut ClientConnection,
118 ) -> Result<Replies, RawClientError> {
119 client_connection.write(request).await.context(WriteRequest {})?;
120 let read_future = client_connection.read();
121 let result = timeout(self.timeout, read_future)
122 .await
123 .context(RequestTimeout {})?;
124 let reply = result.context(ReadReply {})?;
125 if reply.get_request_id() != request.get_request_id() {
126 client_connection.set_failure();
127 return Err(RawClientError::WrongReplyId {
128 reply_id: reply.get_request_id(),
129 request_id: request.get_request_id(),
130 });
131 }
132 check_auth_token_expired(&reply)?;
133 Ok(reply)
134 }
135
136 async fn send_request(&self, request: &Requests) -> Result<Replies, RawClientError> {
137 let connection = self
138 .pool
139 .get_connection(self.endpoint.clone())
140 .await
141 .context(GetConnectionFromPool {})?;
142 let mut client_connection = ClientConnectionImpl::new(connection);
143 client_connection.write(request).await.context(WriteRequest {})?;
144 let read_future = client_connection.read();
145 let result = timeout(self.timeout, read_future)
146 .await
147 .context(RequestTimeout {})?;
148 let reply = result.context(ReadReply {})?;
149 if reply.get_request_id() != request.get_request_id() {
150 client_connection.set_failure();
151 return Err(RawClientError::WrongReplyId {
152 reply_id: reply.get_request_id(),
153 request_id: request.get_request_id(),
154 });
155 }
156 check_auth_token_expired(&reply)?;
157 Ok(reply)
158 }
159
160 async fn send_setup_request(
161 &self,
162 request: &Requests,
163 ) -> Result<(Replies, Box<dyn ClientConnection + 'a>), RawClientError> {
164 let connection = self
165 .pool
166 .get_connection(self.endpoint.clone())
167 .await
168 .context(GetConnectionFromPool {})?;
169 let mut client_connection = ClientConnectionImpl::new(connection);
170 client_connection.write(request).await.context(WriteRequest {})?;
171 let read_future = client_connection.read();
172 let result = timeout(self.timeout, read_future)
173 .await
174 .context(RequestTimeout {})?;
175 let reply = result.context(ReadReply {})?;
176 if reply.get_request_id() != request.get_request_id() {
177 client_connection.set_failure();
178 return Err(RawClientError::WrongReplyId {
179 reply_id: reply.get_request_id(),
180 request_id: request.get_request_id(),
181 });
182 }
183 check_auth_token_expired(&reply)?;
184 Ok((reply, Box::new(client_connection) as Box<dyn ClientConnection>))
185 }
186}
187
188fn check_auth_token_expired(reply: &Replies) -> Result<(), RawClientError> {
189 if let Replies::AuthTokenCheckFailed(ref cmd) = reply {
190 if cmd.is_token_expired() {
191 return Err(RawClientError::AuthTokenExpired { reply: reply.clone() });
192 }
193 }
194 Ok(())
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use pravega_client_config::connection_type::ConnectionType;
    use pravega_wire_protocol::commands::{HelloCommand, ReadSegmentCommand, SegmentReadCommand};
    use pravega_wire_protocol::connection_factory::{ConnectionFactory, ConnectionFactoryConfig};
    use pravega_wire_protocol::wire_commands::Encode;
    use std::io::{Read, Write};
    use std::net::{SocketAddr, TcpListener};
    use std::thread;
    use tokio::runtime::Runtime;

    /// Shared fixture: a Tokio runtime plus a connection pool built on a Tokio connection factory.
    struct Common {
        rt: Runtime,
        pool: ConnectionPool<SegmentConnectionManager>,
    }

    impl Common {
        fn new() -> Self {
            let rt = Runtime::new().expect("create tokio Runtime");
            let config = ConnectionFactoryConfig::new(ConnectionType::Tokio);
            let connection_factory = ConnectionFactory::create(config);
            // NOTE(review): `1` presumably limits pooled connections per endpoint — confirm
            // against `SegmentConnectionManager::new`.
            let manager = SegmentConnectionManager::new(connection_factory, 1);
            let pool = ConnectionPool::new(manager);
            Common { rt, pool }
        }
    }

    /// Minimal blocking TCP server bound to an ephemeral port on 127.0.0.1.
    struct Server {
        address: SocketAddr,
        listener: TcpListener,
    }

    impl Server {
        pub fn new() -> Server {
            let listener = TcpListener::bind("127.0.0.1:0").expect("local server");
            let address = listener.local_addr().unwrap();
            Server { address, listener }
        }

        /// Accepts one connection, writes a Hello reply (high_version 9, low_version 5) without
        /// reading anything from the client, then drops the stream.
        pub fn send_hello(&mut self) {
            let reply = Replies::Hello(HelloCommand {
                high_version: 9,
                low_version: 5,
            })
            .write_fields()
            .expect("serialize hello wirecommand");

            for stream in self.listener.incoming() {
                let mut stream = stream.expect("get tcp stream");
                stream.write_all(&reply).expect("reply with hello wirecommand");
                break;
            }
        }

        /// Like `send_hello`, but advertises high_version 10 / low_version 10.
        /// NOTE(review): not used by any test in this module.
        pub fn send_hello_wrong_version(&mut self) {
            let reply = Replies::Hello(HelloCommand {
                high_version: 10,
                low_version: 10,
            })
            .write_fields()
            .expect("serialize hello wirecommand");

            for stream in self.listener.incoming() {
                let mut stream = stream.expect("get tcp stream");
                stream.write_all(&reply).expect("reply with hello wirecommand");
                break;
            }
        }
    }

    #[test]
    // NOTE(review): the server closes the socket right after writing its single Hello, so the
    // request presumably cannot complete and `expect("get reply")` panics; the assertion and
    // `h.join()` below are then never reached. Asserting on an `Err` result would pin this more
    // precisely than `should_panic` — confirm the intended behaviour.
    #[should_panic]
    fn test_hello() {
        let common = Common::new();
        let mut server = Server::new();

        let raw_client = RawClientImpl::new(
            &common.pool,
            PravegaNodeUri::from(server.address),
            Duration::from_secs(3600),
        );
        let h = thread::spawn(move || {
            server.send_hello();
        });
        let request = Requests::Hello(HelloCommand {
            low_version: 5,
            high_version: 9,
        });

        let reply = common
            .rt
            .block_on(raw_client.send_request(&request))
            .expect("get reply");

        assert_eq!(
            reply,
            Replies::Hello(HelloCommand {
                high_version: 9,
                low_version: 5,
            })
        );
        h.join().expect("thread finished");
    }

    /// A request whose connection receives garbage must fail, and a second request must then
    /// succeed. The server thread answers the second request only on a newly accepted
    /// connection, so the second call presumably passes only if the broken first connection
    /// was not reused.
    #[test]
    fn test_invalid_connection() {
        let common = Common::new();
        let server = Server::new();

        let raw_client = RawClientImpl::new(
            &common.pool,
            PravegaNodeUri::from(server.address),
            Duration::from_secs(30),
        );

        // The server reads the incoming byte stream in fixed 8-byte chunks and reacts to how
        // many chunks have arrived so far, not to wire-command frame boundaries.
        let h = thread::spawn(move || {
            let mut conn = 0;
            for stream in server.listener.incoming() {
                conn += 1;
                if let Ok(mut stream) = stream {
                    if conn == 1 {
                        // First connection: Hello reply on chunk 1, then bytes that presumably do
                        // not decode as a valid reply on chunk 2, then close.
                        let mut cnt = 0;
                        let mut buf = vec![0; 8];
                        while let Ok(_size) = stream.read_exact(&mut buf) {
                            cnt += 1;
                            if cnt == 1 {
                                let reply = Replies::Hello(HelloCommand {
                                    high_version: 15,
                                    low_version: 5,
                                });
                                let bytes = reply.write_fields().expect("serialize reply");
                                stream.write_all(&bytes).expect("send hello");
                            } else if cnt == 2 {
                                let buf =
                                    vec![0, 0, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 255, 255, 255, 255];
                                stream.write_all(&buf).expect("send invalid payload");
                                break;
                            }
                        }
                    } else {
                        // Second connection: Hello reply on chunk 1, then a well-formed
                        // SegmentRead reply (request_id 0) on chunk 2, then close.
                        let mut cnt = 0;
                        let mut buf = vec![0; 8];
                        while let Ok(_size) = stream.read_exact(&mut buf) {
                            cnt += 1;
                            if cnt == 1 {
                                let reply = Replies::Hello(HelloCommand {
                                    high_version: 15,
                                    low_version: 5,
                                });
                                let bytes = reply.write_fields().expect("serialize reply");
                                stream.write_all(&bytes).expect("send valid payload");
                            } else if cnt == 2 {
                                let reply = Replies::SegmentRead(SegmentReadCommand {
                                    segment: "foo".to_string(),
                                    offset: 0,
                                    at_tail: false,
                                    end_of_segment: false,
                                    data: vec![0, 0, 0, 0],
                                    request_id: 0,
                                });
                                let bytes = reply.write_fields().expect("serialize reply");
                                stream.write_all(&bytes).expect("send valid payload");
                                break;
                            }
                        }
                    }
                }
                // Stop accepting after the second connection so the thread can be joined.
                if conn == 2 {
                    break;
                }
            }
        });
        let request = Requests::ReadSegment(ReadSegmentCommand {
            segment: "foo".to_string(),
            offset: 0,
            suggested_length: 0,
            delegation_token: "".to_string(),
            request_id: 0,
        });

        // First attempt hits the invalid payload on connection 1.
        let res = common.rt.block_on(raw_client.send_request(&request));
        assert!(res.is_err());

        // Second attempt must be served by connection 2 with a valid SegmentRead reply.
        let reply = common
            .rt
            .block_on(raw_client.send_request(&request))
            .expect("get reply");

        assert_eq!(
            reply,
            Replies::SegmentRead(SegmentReadCommand {
                segment: "foo".to_string(),
                offset: 0,
                at_tail: false,
                end_of_segment: false,
                data: vec![0, 0, 0, 0],
                request_id: 0
            })
        );
        h.join().expect("thread finished");
    }
}