pravega_wire_protocol/
client_connection.rs

1//
2// Copyright (c) Dell Inc., or its subsidiaries. All Rights Reserved.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10
11extern crate byteorder;
12use crate::commands::MAX_WIRECOMMAND_SIZE;
13use crate::connection::{Connection, ConnectionReadHalf, ConnectionWriteHalf};
14use crate::error::*;
15use crate::wire_commands::{Decode, Encode, Replies, Requests};
16use async_trait::async_trait;
17use byteorder::{BigEndian, ReadBytesExt};
18use pravega_connection_pool::connection_pool::PooledConnection;
19use snafu::{ensure, ResultExt};
20use std::io::Cursor;
21use std::ops::DerefMut;
22use uuid::Uuid;
23
/// Byte offset, counted from the start of a wire command, at which the payload length field
/// begins. The readers in this file fetch `LENGTH_FIELD_OFFSET + LENGTH_FIELD_LENGTH` bytes as
/// the header before looking at the payload.
// NOTE(review): the bytes before this offset presumably carry the command type code — confirm
// against the wire command encoder.
pub const LENGTH_FIELD_OFFSET: u32 = 4;
/// Size in bytes of the payload length field. The readers in this file decode the field as a
/// big-endian `u32`.
pub const LENGTH_FIELD_LENGTH: u32 = 4;
26
/// ClientConnection sits on top of a `Connection` and exchanges whole wire commands
/// (`Requests` / `Replies`) instead of raw bytes.
#[async_trait]
pub trait ClientConnection: Send + Sync {
    /// Reads the next wire command from the connection and returns it as a `Replies`.
    async fn read(&mut self) -> Result<Replies, ClientConnectionError>;
    /// Sends `request` over the connection as a wire command.
    async fn write(&mut self, request: &Requests) -> Result<(), ClientConnectionError>;
    /// Splits the connection into a read half and a write half that can be used separately.
    fn split(&mut self) -> (ClientConnectionReadHalf, ClientConnectionWriteHalf);
    /// Returns the `Uuid` of the underlying connection.
    fn get_uuid(&self) -> Uuid;
    /// Marks the connection as failed. The implementation in this file does so by calling
    /// `can_recycle(false)` on the underlying connection.
    fn set_failure(&mut self);
}
36
/// `ClientConnection` implementation that wraps a `Connection` borrowed from a connection pool.
pub struct ClientConnectionImpl<'a> {
    /// The pooled connection that all reads and writes are forwarded to.
    pub connection: PooledConnection<'a, Box<dyn Connection>>,
}
40
/// Read side of a split client connection, as returned by `ClientConnection::split`.
/// It reads whole wire commands from the wrapped `ConnectionReadHalf`.
pub struct ClientConnectionReadHalf {
    // Raw read half that the wire command bytes are read from.
    read_half: Box<dyn ConnectionReadHalf>,
}
44
/// Write side of a split client connection, as returned by `ClientConnection::split`.
/// It encodes `Requests` and sends them through the wrapped `ConnectionWriteHalf`.
#[derive(Debug)]
pub struct ClientConnectionWriteHalf {
    // Raw write half that the encoded wire command bytes are sent through.
    write_half: Box<dyn ConnectionWriteHalf>,
}
49
50impl<'a> ClientConnectionImpl<'a> {
51    pub fn new(connection: PooledConnection<'a, Box<dyn Connection>>) -> Self {
52        ClientConnectionImpl { connection }
53    }
54}
55
56impl ClientConnectionReadHalf {
57    pub async fn read(&mut self) -> Result<Replies, ClientConnectionError> {
58        let mut header: Vec<u8> = vec![0; LENGTH_FIELD_OFFSET as usize + LENGTH_FIELD_LENGTH as usize];
59        self.read_half.read_async(&mut header[..]).await.context(Read {
60            part: "header".to_string(),
61        })?;
62        let mut rdr = Cursor::new(&header[4..8]);
63        let payload_length = rdr.read_u32::<BigEndian>().expect("exact size");
64        ensure!(
65            payload_length <= MAX_WIRECOMMAND_SIZE,
66            PayloadLengthTooLong {
67                payload_size: payload_length,
68                max_wirecommand_size: MAX_WIRECOMMAND_SIZE
69            }
70        );
71        let mut payload: Vec<u8> = vec![0; payload_length as usize];
72        self.read_half.read_async(&mut payload[..]).await.context(Read {
73            part: "payload".to_string(),
74        })?;
75        let concatenated = [&header[..], &payload[..]].concat();
76        let reply: Replies = Replies::read_from(&concatenated).context(DecodeCommand {})?;
77        Ok(reply)
78    }
79
80    pub fn get_id(&self) -> Uuid {
81        self.read_half.get_id()
82    }
83}
84
85impl ClientConnectionWriteHalf {
86    pub async fn write(&mut self, request: &Requests) -> Result<(), ClientConnectionError> {
87        let payload = request.write_fields().context(EncodeCommand {})?;
88        self.write_half.send_async(&payload).await.context(Write {})
89    }
90
91    pub fn get_id(&self) -> Uuid {
92        self.write_half.get_id()
93    }
94}
95
96#[async_trait]
97#[allow(clippy::needless_lifetimes)] //Normally the compiler could infer lifetimes but async is throwing it for a loop.
98impl ClientConnection for ClientConnectionImpl<'_> {
99    async fn read(&mut self) -> Result<Replies, ClientConnectionError> {
100        read_wirecommand(&mut **self.connection.deref_mut()).await
101    }
102
103    async fn write(&mut self, request: &Requests) -> Result<(), ClientConnectionError> {
104        write_wirecommand(&mut **self.connection.deref_mut(), request).await
105    }
106
107    fn split(&mut self) -> (ClientConnectionReadHalf, ClientConnectionWriteHalf) {
108        let (r, w) = self.connection.split();
109        let reader = ClientConnectionReadHalf { read_half: r };
110        let writer = ClientConnectionWriteHalf { write_half: w };
111        (reader, writer)
112    }
113
114    fn get_uuid(&self) -> Uuid {
115        self.connection.get_uuid()
116    }
117
118    fn set_failure(&mut self) {
119        self.connection.can_recycle(false);
120    }
121}
122
123pub async fn read_wirecommand(connection: &mut dyn Connection) -> Result<Replies, ClientConnectionError> {
124    connection.can_recycle(false);
125    let mut header: Vec<u8> = vec![0; LENGTH_FIELD_OFFSET as usize + LENGTH_FIELD_LENGTH as usize];
126    connection.read_async(&mut header[..]).await.context(Read {
127        part: "header".to_string(),
128    })?;
129    let mut rdr = Cursor::new(&header[4..8]);
130    let payload_length = rdr.read_u32::<BigEndian>().expect("exact size");
131    ensure!(
132        payload_length <= MAX_WIRECOMMAND_SIZE,
133        PayloadLengthTooLong {
134            payload_size: payload_length,
135            max_wirecommand_size: MAX_WIRECOMMAND_SIZE
136        }
137    );
138    let mut payload: Vec<u8> = vec![0; payload_length as usize];
139    connection.read_async(&mut payload[..]).await.context(Read {
140        part: "payload".to_string(),
141    })?;
142    let concatenated = [&header[..], &payload[..]].concat();
143    let reply: Replies = Replies::read_from(&concatenated).context(DecodeCommand {})?;
144    connection.can_recycle(true);
145    Ok(reply)
146}
147
148pub async fn write_wirecommand(
149    connection: &mut dyn Connection,
150    request: &Requests,
151) -> Result<(), ClientConnectionError> {
152    connection.can_recycle(false);
153    let payload = request.write_fields().context(EncodeCommand {})?;
154    if let Err(e) = connection.send_async(&payload).await.context(Write {}) {
155        return Err(e);
156    } else {
157        connection.can_recycle(true);
158    }
159    Ok(())
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commands::HelloCommand;
    use crate::connection_factory::{ConnectionFactory, ConnectionFactoryConfig, SegmentConnectionManager};
    use crate::wire_commands::Replies;
    use pravega_client_config::connection_type::{ConnectionType, MockType};
    use pravega_client_shared::PravegaNodeUri;
    use pravega_connection_pool::connection_pool::ConnectionPool;
    use tokio::runtime::Runtime;

    /// Builds the Hello command that the test sends and expects to receive back.
    fn hello_command() -> HelloCommand {
        HelloCommand {
            high_version: 9,
            low_version: 5,
        }
    }

    /// Writes a Hello request through a mock-backed pooled connection and checks that the reply
    /// read back is the same Hello command.
    #[test]
    fn client_connection_write_and_read() {
        let runtime = Runtime::new().expect("create tokio Runtime");

        // Pool holding mock connections of the "happy" kind.
        let factory_config = ConnectionFactoryConfig::new(ConnectionType::Mock(MockType::Happy));
        let factory = ConnectionFactory::create(factory_config);
        let pool = ConnectionPool::new(SegmentConnectionManager::new(factory, 1));

        let endpoint = PravegaNodeUri::from("127.0.0.1:9090");
        let pooled = runtime
            .block_on(pool.get_connection(endpoint))
            .expect("get connection from pool");
        let mut client_connection = ClientConnectionImpl::new(pooled);

        // Send the request as a wire command.
        let outgoing = Requests::Hello(hello_command());
        runtime
            .block_on(client_connection.write(&outgoing))
            .expect("client connection write");

        // Read the reply as a wire command.
        let incoming = runtime
            .block_on(client_connection.read())
            .expect("client connection read");

        assert_eq!(incoming, Replies::Hello(hello_command()));
    }
}