pravega_wire_protocol/
client_connection.rs1extern crate byteorder;
12use crate::commands::MAX_WIRECOMMAND_SIZE;
13use crate::connection::{Connection, ConnectionReadHalf, ConnectionWriteHalf};
14use crate::error::*;
15use crate::wire_commands::{Decode, Encode, Replies, Requests};
16use async_trait::async_trait;
17use byteorder::{BigEndian, ReadBytesExt};
18use pravega_connection_pool::connection_pool::PooledConnection;
19use snafu::{ensure, ResultExt};
20use std::io::Cursor;
21use std::ops::DerefMut;
22use uuid::Uuid;
23
/// Byte offset, inside a wire-command header, at which the payload-length field starts.
///
/// NOTE(review): the bytes before this offset presumably carry the command type; confirm
/// against the wire-protocol definition.
pub const LENGTH_FIELD_OFFSET: u32 = 4;
/// Width in bytes of the big-endian payload-length field located at [`LENGTH_FIELD_OFFSET`].
pub const LENGTH_FIELD_LENGTH: u32 = 4;
26
27#[async_trait]
29pub trait ClientConnection: Send + Sync {
30 async fn read(&mut self) -> Result<Replies, ClientConnectionError>;
31 async fn write(&mut self, request: &Requests) -> Result<(), ClientConnectionError>;
32 fn split(&mut self) -> (ClientConnectionReadHalf, ClientConnectionWriteHalf);
33 fn get_uuid(&self) -> Uuid;
34 fn set_failure(&mut self);
35}
36
37pub struct ClientConnectionImpl<'a> {
38 pub connection: PooledConnection<'a, Box<dyn Connection>>,
39}
40
41pub struct ClientConnectionReadHalf {
42 read_half: Box<dyn ConnectionReadHalf>,
43}
44
45#[derive(Debug)]
46pub struct ClientConnectionWriteHalf {
47 write_half: Box<dyn ConnectionWriteHalf>,
48}
49
50impl<'a> ClientConnectionImpl<'a> {
51 pub fn new(connection: PooledConnection<'a, Box<dyn Connection>>) -> Self {
52 ClientConnectionImpl { connection }
53 }
54}
55
56impl ClientConnectionReadHalf {
57 pub async fn read(&mut self) -> Result<Replies, ClientConnectionError> {
58 let mut header: Vec<u8> = vec![0; LENGTH_FIELD_OFFSET as usize + LENGTH_FIELD_LENGTH as usize];
59 self.read_half.read_async(&mut header[..]).await.context(Read {
60 part: "header".to_string(),
61 })?;
62 let mut rdr = Cursor::new(&header[4..8]);
63 let payload_length = rdr.read_u32::<BigEndian>().expect("exact size");
64 ensure!(
65 payload_length <= MAX_WIRECOMMAND_SIZE,
66 PayloadLengthTooLong {
67 payload_size: payload_length,
68 max_wirecommand_size: MAX_WIRECOMMAND_SIZE
69 }
70 );
71 let mut payload: Vec<u8> = vec![0; payload_length as usize];
72 self.read_half.read_async(&mut payload[..]).await.context(Read {
73 part: "payload".to_string(),
74 })?;
75 let concatenated = [&header[..], &payload[..]].concat();
76 let reply: Replies = Replies::read_from(&concatenated).context(DecodeCommand {})?;
77 Ok(reply)
78 }
79
80 pub fn get_id(&self) -> Uuid {
81 self.read_half.get_id()
82 }
83}
84
85impl ClientConnectionWriteHalf {
86 pub async fn write(&mut self, request: &Requests) -> Result<(), ClientConnectionError> {
87 let payload = request.write_fields().context(EncodeCommand {})?;
88 self.write_half.send_async(&payload).await.context(Write {})
89 }
90
91 pub fn get_id(&self) -> Uuid {
92 self.write_half.get_id()
93 }
94}
95
96#[async_trait]
97#[allow(clippy::needless_lifetimes)] impl ClientConnection for ClientConnectionImpl<'_> {
99 async fn read(&mut self) -> Result<Replies, ClientConnectionError> {
100 read_wirecommand(&mut **self.connection.deref_mut()).await
101 }
102
103 async fn write(&mut self, request: &Requests) -> Result<(), ClientConnectionError> {
104 write_wirecommand(&mut **self.connection.deref_mut(), request).await
105 }
106
107 fn split(&mut self) -> (ClientConnectionReadHalf, ClientConnectionWriteHalf) {
108 let (r, w) = self.connection.split();
109 let reader = ClientConnectionReadHalf { read_half: r };
110 let writer = ClientConnectionWriteHalf { write_half: w };
111 (reader, writer)
112 }
113
114 fn get_uuid(&self) -> Uuid {
115 self.connection.get_uuid()
116 }
117
118 fn set_failure(&mut self) {
119 self.connection.can_recycle(false);
120 }
121}
122
123pub async fn read_wirecommand(connection: &mut dyn Connection) -> Result<Replies, ClientConnectionError> {
124 connection.can_recycle(false);
125 let mut header: Vec<u8> = vec![0; LENGTH_FIELD_OFFSET as usize + LENGTH_FIELD_LENGTH as usize];
126 connection.read_async(&mut header[..]).await.context(Read {
127 part: "header".to_string(),
128 })?;
129 let mut rdr = Cursor::new(&header[4..8]);
130 let payload_length = rdr.read_u32::<BigEndian>().expect("exact size");
131 ensure!(
132 payload_length <= MAX_WIRECOMMAND_SIZE,
133 PayloadLengthTooLong {
134 payload_size: payload_length,
135 max_wirecommand_size: MAX_WIRECOMMAND_SIZE
136 }
137 );
138 let mut payload: Vec<u8> = vec![0; payload_length as usize];
139 connection.read_async(&mut payload[..]).await.context(Read {
140 part: "payload".to_string(),
141 })?;
142 let concatenated = [&header[..], &payload[..]].concat();
143 let reply: Replies = Replies::read_from(&concatenated).context(DecodeCommand {})?;
144 connection.can_recycle(true);
145 Ok(reply)
146}
147
148pub async fn write_wirecommand(
149 connection: &mut dyn Connection,
150 request: &Requests,
151) -> Result<(), ClientConnectionError> {
152 connection.can_recycle(false);
153 let payload = request.write_fields().context(EncodeCommand {})?;
154 if let Err(e) = connection.send_async(&payload).await.context(Write {}) {
155 return Err(e);
156 } else {
157 connection.can_recycle(true);
158 }
159 Ok(())
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commands::HelloCommand;
    use crate::connection_factory::{ConnectionFactory, ConnectionFactoryConfig, SegmentConnectionManager};
    use crate::wire_commands::Replies;
    use pravega_client_config::connection_type::{ConnectionType, MockType};
    use pravega_client_shared::PravegaNodeUri;
    use pravega_connection_pool::connection_pool::ConnectionPool;
    use tokio::runtime::Runtime;

    /// Hello command used both as the request and as the expected reply.
    fn hello_command() -> HelloCommand {
        HelloCommand {
            high_version: 9,
            low_version: 5,
        }
    }

    /// Writes a Hello request through a pooled `MockType::Happy` connection and checks that
    /// the reply read back is the same Hello command.
    #[test]
    fn client_connection_write_and_read() {
        let runtime = Runtime::new().expect("create tokio Runtime");
        let mock_config = ConnectionFactoryConfig::new(ConnectionType::Mock(MockType::Happy));
        let factory = ConnectionFactory::create(mock_config);
        let pool = ConnectionPool::new(SegmentConnectionManager::new(factory, 1));
        let endpoint = PravegaNodeUri::from("127.0.0.1:9090");
        let pooled = runtime
            .block_on(pool.get_connection(endpoint))
            .expect("get connection from pool");

        let mut client_connection = ClientConnectionImpl::new(pooled);
        runtime
            .block_on(client_connection.write(&Requests::Hello(hello_command())))
            .expect("client connection write");
        let reply = runtime
            .block_on(client_connection.read())
            .expect("client connection read");

        assert_eq!(reply, Replies::Hello(hello_command()));
    }
}