rift/protocol/
binary.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
19use std::cell::RefCell;
20use std::convert::From;
21use std::io::{Read, Write};
22use std::rc::Rc;
23use try_from::TryFrom;
24
25use ::{ProtocolError, ProtocolErrorKind};
26use ::transport::TTransport;
27use super::{TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType};
28use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
29
/// Version marker placed in the upper 16 bits of the strict-mode message
/// header. `TBinaryOutputProtocol::write_message_begin` ORs the message type
/// into the low bits and writes the result as a big-endian `u32`.
const BINARY_PROTOCOL_VERSION_1: u32 = 0x80010000;
31
/// Reads messages encoded using the Thrift simple binary encoding.
///
/// All multi-byte values are read from the wrapped transport in big-endian
/// byte order.
pub struct TBinaryInputProtocol {
    /// When `true`, `read_message_begin` rejects messages that do not start
    /// with the protocol-version header.
    strict: bool,
    /// Shared handle to the transport the encoded bytes are read from.
    transport: Rc<RefCell<Box<TTransport>>>,
}
37
38impl TBinaryInputProtocol {
39    pub fn new(strict: bool, transport: Rc<RefCell<Box<TTransport>>>) -> TBinaryInputProtocol {
40        TBinaryInputProtocol { strict: strict, transport: transport }
41    }
42}
43
44impl TInputProtocol for TBinaryInputProtocol {
45    fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> {
46        let mut first_bytes = vec![0; 4];
47        try!(self.transport.borrow_mut().read_exact(&mut first_bytes[..]));
48
49        // the thrift version header is intentionally negative
50        // so the first check we'll do is see if the sign bit is set
51        // and if so - assume it's the protocol-version header
52        if first_bytes[0] >= 8 {
53            // apparently we got a protocol-version header - check
54            // it, and if it matches, read the rest of the fields
55            if first_bytes[0..2] != [0x80, 0x01] {
56                Err(
57                    ::Error::Protocol(
58                        ProtocolError {
59                            kind: ProtocolErrorKind::BadVersion,
60                            message: format!("received bad version: {:?}", &first_bytes[0..2]),
61                        }
62                    )
63                )
64            } else {
65                let message_type: TMessageType = try!(TryFrom::try_from(first_bytes[3]));
66                let name = try!(self.read_string());
67                let sequence_number = try!(self.read_i32());
68                Ok(TMessageIdentifier::new(name, message_type, sequence_number))
69            }
70        } else {
71            // apparently we didn't get a protocol-version header,
72            // which happens if the sender is not using the strict protocol
73            if self.strict {
74                // we're in strict mode however, and that always
75                // requires the protocol-version header to be written first
76                Err(
77                    ::Error::Protocol(
78                        ProtocolError {
79                            kind: ProtocolErrorKind::BadVersion,
80                            message: format!("received bad version: {:?}", &first_bytes[0..2]),
81                        }
82                    )
83                )
84            } else {
85                // in the non-strict version the first message field
86                // is the message name. strings (byte arrays) are length-prefixed,
87                // so we've just read the length in the first 4 bytes
88                let name_size = BigEndian::read_i32(&first_bytes) as usize;
89                let mut name_buf: Vec<u8> = Vec::with_capacity(name_size);
90                try!(self.transport.borrow_mut().read_exact(&mut name_buf));
91                let name = try!(String::from_utf8(name_buf));
92
93                // read the rest of the fields
94                let message_type: TMessageType = try!(self.read_byte().and_then(TryFrom::try_from));
95                let sequence_number = try!(self.read_i32());
96                Ok(TMessageIdentifier::new(name, message_type, sequence_number))
97            }
98        }
99    }
100
101    fn read_message_end(&mut self) -> ::Result<()> {
102        Ok(())
103    }
104
105    fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> {
106        Ok(None)
107    }
108
109    fn read_struct_end(&mut self) -> ::Result<()> {
110        Ok(())
111    }
112
113    fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> {
114        let field_type_byte = try!(self.read_byte());
115        let field_type = try!(field_type_from_u8(field_type_byte));
116        let id = try!(match field_type {
117            TType::Stop => Ok(0),
118            _ => self.read_i16()
119        });
120        Ok(TFieldIdentifier::new::<Option<String>, String, i16>(None, field_type, id))
121    }
122
123    fn read_field_end(&mut self) -> ::Result<()> {
124        Ok(())
125    }
126
127    fn read_bytes(&mut self) -> ::Result<Vec<u8>> {
128        let num_bytes = try!(self.transport.borrow_mut().read_i32::<BigEndian>()) as usize;
129        let mut buf = vec![0u8; num_bytes];
130        self.transport.borrow_mut().read_exact(&mut buf).map(|_| buf).map_err(From::from)
131    }
132
133    fn read_bool(&mut self) -> ::Result<bool> {
134        let b = try!(self.read_i8());
135        match b {
136            0 => Ok(false),
137            _ => Ok(true),
138        }
139    }
140
141    fn read_i8(&mut self) -> ::Result<i8> {
142        self.transport.borrow_mut().read_i8().map_err(From::from)
143    }
144
145    fn read_i16(&mut self) -> ::Result<i16> {
146        self.transport.borrow_mut().read_i16::<BigEndian>().map_err(From::from)
147    }
148
149    fn read_i32(&mut self) -> ::Result<i32> {
150        self.transport.borrow_mut().read_i32::<BigEndian>().map_err(From::from)
151    }
152
153    fn read_i64(&mut self) -> ::Result<i64> {
154        self.transport.borrow_mut().read_i64::<BigEndian>().map_err(From::from)
155    }
156
157    fn read_double(&mut self) -> ::Result<f64> {
158        self.transport.borrow_mut().read_f64::<BigEndian>().map_err(From::from)
159    }
160
161    fn read_string(&mut self) -> ::Result<String> {
162        let bytes = try!(self.read_bytes());
163        String::from_utf8(bytes).map_err(From::from)
164    }
165
166    fn read_list_begin(&mut self) -> ::Result<TListIdentifier> {
167        let element_type: TType = try!(self.read_byte().and_then(field_type_from_u8));
168        let size = try!(self.read_i32());
169        Ok(TListIdentifier::new(element_type, size))
170    }
171
172    fn read_list_end(&mut self) -> ::Result<()> {
173        Ok(())
174    }
175
176    fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> {
177        let element_type: TType = try!(self.read_byte().and_then(field_type_from_u8));
178        let size = try!(self.read_i32());
179        Ok(TSetIdentifier::new(element_type, size))
180    }
181
182    fn read_set_end(&mut self) -> ::Result<()> {
183        Ok(())
184    }
185
186    fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> {
187        let key_type: TType = try!(self.read_byte().and_then(field_type_from_u8));
188        let value_type: TType = try!(self.read_byte().and_then(field_type_from_u8));
189        let size = try!(self.read_i32());
190        Ok(TMapIdentifier::new(key_type, value_type, size))
191    }
192
193    fn read_map_end(&mut self) -> ::Result<()> {
194        Ok(())
195    }
196
197    //
198    // utility
199    //
200
201    fn read_byte(&mut self) -> ::Result<u8> {
202        self.transport.borrow_mut().read_u8().map_err(From::from)
203    }
204}
205
/// Creates instances of `TBinaryInputProtocol` that use the strict Thrift
/// binary encoding.
///
/// The factory holds no state; its `create` implementation always passes
/// `strict = true` to `TBinaryInputProtocol::new`.
pub struct TBinaryInputProtocolFactory;
209impl TInputProtocolFactory for TBinaryInputProtocolFactory {
210    fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TInputProtocol> {
211        Box::new(TBinaryInputProtocol::new(true, transport)) as Box<TInputProtocol>
212    }
213}
214
/// Encodes messages using the Thrift simple binary encoding.
///
/// All multi-byte values are written to the wrapped transport in big-endian
/// byte order.
pub struct TBinaryOutputProtocol {
    /// When `true`, `write_message_begin` emits the protocol-version header
    /// before the message name; otherwise the name is written first.
    strict: bool,
    /// Shared handle to the transport the encoded bytes are written to.
    transport: Rc<RefCell<Box<TTransport>>>,
}
220
221impl TBinaryOutputProtocol {
222    pub fn new(strict: bool, transport: Rc<RefCell<Box<TTransport>>>) -> TBinaryOutputProtocol {
223        TBinaryOutputProtocol { strict: strict, transport: transport }
224    }
225
226    fn write_transport(&mut self, buf: &[u8]) -> ::Result<()> {
227        self.transport.borrow_mut().write(buf).map(|_| ()).map_err(From::from)
228    }
229}
230
231impl TOutputProtocol for TBinaryOutputProtocol {
232    fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> {
233        if self.strict {
234            let message_type: u8 = identifier.message_type.into();
235            let header = BINARY_PROTOCOL_VERSION_1 | (message_type as u32);
236            try!(self.transport.borrow_mut().write_u32::<BigEndian>(header));
237            try!(self.write_string(&identifier.name));
238            self.write_i32(identifier.sequence_number)
239        } else {
240            try!(self.write_string(&identifier.name));
241            try!(self.write_byte(identifier.message_type.into()));
242            self.write_i32(identifier.sequence_number)
243        }
244    }
245
246    fn write_message_end(&mut self) -> ::Result<()> {
247        Ok(())
248    }
249
250    fn write_struct_begin(&mut self, _: &TStructIdentifier) -> ::Result<()> {
251        Ok(())
252    }
253
254    fn write_struct_end(&mut self) -> ::Result<()> {
255        Ok(())
256    }
257
258    fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> {
259        if identifier.id.is_none() && identifier.field_type != TType::Stop {
260            return Err(
261                ::Error::Protocol(
262                    ProtocolError {
263                        kind: ProtocolErrorKind::Unknown,
264                        message: format!("cannot write identifier {:?} without sequence number", &identifier),
265                    }
266                )
267            )
268        }
269
270        try!(self.write_byte(field_type_to_u8(identifier.field_type)));
271        if let Some(id) = identifier.id {
272            self.write_i16(id)
273        } else {
274            Ok(())
275        }
276    }
277
278    fn write_field_end(&mut self) -> ::Result<()> {
279        Ok(())
280    }
281
282    fn write_field_stop(&mut self) -> ::Result<()> {
283        self.write_byte(field_type_to_u8(TType::Stop))
284    }
285
286    fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> {
287        try!(self.write_i32(b.len() as i32));
288        self.write_transport(b)
289    }
290
291    fn write_bool(&mut self, b: bool) -> ::Result<()> {
292        if b {
293            self.write_i8(1)
294        } else {
295            self.write_i8(0)
296        }
297    }
298
299    fn write_i8(&mut self, i: i8) -> ::Result<()> {
300        self.transport.borrow_mut().write_i8(i).map_err(From::from)
301    }
302
303    fn write_i16(&mut self, i: i16) -> ::Result<()> {
304        self.transport.borrow_mut().write_i16::<BigEndian>(i).map_err(From::from)
305    }
306
307    fn write_i32(&mut self, i: i32) -> ::Result<()> {
308        self.transport.borrow_mut().write_i32::<BigEndian>(i).map_err(From::from)
309    }
310
311    fn write_i64(&mut self, i: i64) -> ::Result<()> {
312        self.transport.borrow_mut().write_i64::<BigEndian>(i).map_err(From::from)
313    }
314
315    fn write_double(&mut self, d: f64) -> ::Result<()> {
316        self.transport.borrow_mut().write_f64::<BigEndian>(d).map_err(From::from)
317    }
318
319    fn write_string(&mut self, s: &str) -> ::Result<()> {
320        self.write_bytes(s.as_bytes())
321    }
322
323    fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> {
324        try!(self.write_byte(field_type_to_u8(identifier.element_type)));
325        self.write_i32(identifier.size)
326    }
327
328    fn write_list_end(&mut self) -> ::Result<()> {
329        Ok(())
330    }
331
332    fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> {
333        try!(self.write_byte(field_type_to_u8(identifier.element_type)));
334        self.write_i32(identifier.size)
335    }
336
337    fn write_set_end(&mut self) -> ::Result<()> {
338        Ok(())
339    }
340
341    fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> {
342        let key_type = identifier.key_type.expect("map identifier to write should contain key type");
343        try!(self.write_byte(field_type_to_u8(key_type)));
344        let val_type = identifier.value_type.expect("map identifier to write should contain value type");
345        try!(self.write_byte(field_type_to_u8(val_type)));
346        self.write_i32(identifier.size)
347    }
348
349    fn write_map_end(&mut self) -> ::Result<()> {
350        Ok(())
351    }
352
353    fn flush(&mut self) -> ::Result<()> {
354        self.transport.borrow_mut().flush().map_err(From::from)
355    }
356
357    //
358    // utility
359    //
360
361    fn write_byte(&mut self, b: u8) -> ::Result<()> {
362        self.transport.borrow_mut().write_u8(b).map_err(From::from)
363    }
364}
365
/// Creates instances of `TBinaryOutputProtocol` that use the strict Thrift
/// binary encoding.
///
/// The factory holds no state; its `create` implementation always passes
/// `strict = true` to `TBinaryOutputProtocol::new`.
pub struct TBinaryOutputProtocolFactory;
369impl TOutputProtocolFactory for TBinaryOutputProtocolFactory {
370    fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TOutputProtocol> {
371        Box::new(TBinaryOutputProtocol::new(true, transport)) as Box<TOutputProtocol>
372    }
373}
374
375fn field_type_to_u8(field_type: TType) -> u8 {
376    match field_type {
377        TType::Stop => 0x00,
378        TType::Void => 0x01,
379        TType::Bool => 0x02,
380        TType::I08 => 0x03, // equivalent to TType::Byte
381        TType::Double => 0x04,
382        TType::I16 => 0x06,
383        TType::I32 => 0x08,
384        TType::I64 => 0x0A,
385        TType::String => 0x0B,
386        TType::Utf7 => 0x0B,
387        TType::Struct => 0x0C,
388        TType::Map => 0x0D,
389        TType::Set => 0x0E,
390        TType::List => 0x0F,
391        TType::Utf8 => 0x10,
392        TType::Utf16 => 0x11,
393    }
394}
395
396fn field_type_from_u8(b: u8) -> ::Result<TType> {
397    match b {
398        0x00 => Ok(TType::Stop),
399        0x01 => Ok(TType::Void),
400        0x02 => Ok(TType::Bool),
401        0x03 => Ok(TType::I08), // Equivalent to TType::Byte
402        0x04 => Ok(TType::Double),
403        0x06 => Ok(TType::I16),
404        0x08 => Ok(TType::I32),
405        0x0A => Ok(TType::I64),
406        0x0B => Ok(TType::String), // technically, also a UTF7, but we'll treat it as string
407        0x0C => Ok(TType::Struct),
408        0x0D => Ok(TType::Map),
409        0x0E => Ok(TType::Set),
410        0x0F => Ok(TType::List),
411        0x10 => Ok(TType::Utf8),
412        0x11 => Ok(TType::Utf16),
413        unkn => Err(
414            ::Error::Protocol(
415                ProtocolError {
416                    kind: ProtocolErrorKind::InvalidData,
417                    message: format!("cannot convert {} to TType", unkn),
418                }
419            )
420        )
421    }
422}
423
#[cfg(test)]
mod tests {
    // NOTE(review): every test in this module is commented out, so the module
    // currently contributes no tests. The disabled code constructs a
    // `TBinaryProtocol`, a type that is not defined in this file - presumably
    // it predates the split into separate input and output protocols; confirm
    // and port to `TBinaryInputProtocol` / `TBinaryOutputProtocol` before
    // re-enabling.
//
//    use std::rc::Rc;
//    use std::cell::RefCell;
//
//    use super::*;
//    use ::protocol::{TMessageIdentifier, TMessageType, TProtocol};
//    use ::transport::{TPassThruTransport, TTransport};
//    use ::transport::mem::TBufferTransport;
//
//    macro_rules! test_objects {
//        () => (
//            {
//                let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity(40, 40))));
//                let inner: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() });
//                let proto = TBinaryProtocol { strict: true, transport: Rc::new(RefCell::new(inner)) };
//                (mem, proto)
//            }
//        );
//    }
//
//    #[test]
//    fn must_round_trip_strict_service_call_message_header() {
//        let (trans, mut proto) = test_objects!();
//
//        let sent_ident = TMessageIdentifier { name: "test".to_owned(), message_type: TMessageType::Call, sequence_number: 1 };
//        assert!(proto.write_message_begin(&sent_ident).is_ok());
//
//        let buf = {
//            let m = trans.borrow();
//            let written = m.write_buffer();
//            let mut b = Vec::with_capacity(written.len());
//            b.extend_from_slice(&written);
//            b
//        };
//
//        let bytes_copied = trans.borrow_mut().set_readable_bytes(&buf);
//        assert_eq!(bytes_copied, buf.len());
//
//        let received_ident_result = proto.read_message_begin();
//        assert!(received_ident_result.is_ok());
//        assert_eq!(received_ident_result.unwrap(), sent_ident);
//    }
//
//    #[test]
//    fn must_write_message_end() {
//        let (trans, mut proto) = test_objects!();
//        assert!(proto.write_message_end().is_ok());
//        assert_eq!(trans.borrow().write_buffer().len(), 0);
//    }
}