pub mod client;
pub mod server;
use std::io::{self, BufRead, ErrorKind, Read};
use std::sync::Arc;
use std::{cmp, mem, ptr, slice, usize};
use crate::cq::CompletionQueue;
use crate::grpc_sys::{
self, GrpcBatchContext, GrpcByteBufferReader, GrpcCall, GrpcCallStatus, GrpcSlice,
};
#[cfg(feature = "prost-codec")]
use bytes::Buf;
use futures::{Async, Future, Poll};
use libc::c_void;
use crate::codec::{DeserializeFn, Marshaller, SerializeFn};
use crate::error::{Error, Result};
use crate::task::{self, BatchFuture, BatchType, CallTag, SpinLock};
pub use crate::grpc_sys::GrpcStatusCode as RpcStatusCode;
/// Initialises a `GrpcByteBufferReader` for a byte buffer.
impl<'a> From<&'a mut GrpcByteBuffer> for GrpcByteBufferReader {
    /// Zeroes a reader and hands it to `grpc_byte_buffer_reader_init` together
    /// with the raw pointer of `src`.
    ///
    /// # Panics
    ///
    /// Panics if `grpc_byte_buffer_reader_init` does not return `1`.
    fn from(src: &'a mut GrpcByteBuffer) -> Self {
        // The reader starts out zeroed; the init call below fills it in.
        let mut reader: Self = unsafe { mem::zeroed() };
        let init_result =
            unsafe { grpc_sys::grpc_byte_buffer_reader_init(&mut reader, src.raw) };
        assert_eq!(init_result, 1);
        reader
    }
}
/// Owning handle to a raw `grpc_sys::GrpcByteBuffer`.
///
/// The wrapped pointer is passed to `grpc_byte_buffer_destroy` when the
/// handle is dropped.
pub struct GrpcByteBuffer {
/// Raw byte-buffer pointer handed to the `grpc_sys` functions.
pub raw: *mut grpc_sys::GrpcByteBuffer,
}
impl Default for GrpcByteBuffer {
    /// Creates a byte buffer by calling `grpc_raw_byte_buffer_create` with a
    /// null slice pointer and a slice count of zero.
    fn default() -> Self {
        let raw = unsafe { grpc_sys::grpc_raw_byte_buffer_create(ptr::null_mut(), 0) };
        GrpcByteBuffer { raw }
    }
}
impl<'a> From<&'a mut [GrpcSlice]> for GrpcByteBuffer {
    /// Creates a byte buffer by passing the pointer and element count of
    /// `slice` to `grpc_raw_byte_buffer_create`.
    fn from(slice: &'a mut [GrpcSlice]) -> Self {
        let raw =
            unsafe { grpc_sys::grpc_raw_byte_buffer_create(slice.as_mut_ptr(), slice.len()) };
        GrpcByteBuffer { raw }
    }
}
impl Clone for GrpcByteBuffer {
    /// Produces a new handle whose raw pointer comes from
    /// `grpc_byte_buffer_copy` applied to this buffer.
    fn clone(&self) -> Self {
        let raw = unsafe { grpc_sys::grpc_byte_buffer_copy(self.raw) };
        GrpcByteBuffer { raw }
    }
}
impl Drop for GrpcByteBuffer {
    /// Hands the raw pointer to `grpc_byte_buffer_destroy`.
    fn drop(&mut self) {
        let raw = self.raw;
        unsafe { grpc_sys::grpc_byte_buffer_destroy(raw) }
    }
}
/// The kind of an RPC method.
///
/// `Debug`, `PartialEq` and `Eq` are derived so that values can be logged and
/// compared directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MethodType {
    /// Unary method.
    Unary,
    /// Client-streaming method.
    ClientStreaming,
    /// Server-streaming method.
    ServerStreaming,
    /// Duplex (bidirectional streaming) method.
    Duplex,
}
/// Description of an RPC method: its kind, its name and the marshallers for
/// the request and response types.
pub struct Method<Req, Resp> {
/// Kind of the method.
pub ty: MethodType,
/// Name of the method.
pub name: &'static str,
/// Marshaller for `Req` values.
pub req_mar: Marshaller<Req>,
/// Marshaller for `Resp` values.
pub resp_mar: Marshaller<Resp>,
}
impl<Req, Resp> Method<Req, Resp> {
    /// Returns the `ser` function of the request marshaller.
    #[inline]
    pub fn req_ser(&self) -> SerializeFn<Req> {
        let marshaller = &self.req_mar;
        marshaller.ser
    }

    /// Returns the `de` function of the request marshaller.
    #[inline]
    pub fn req_de(&self) -> DeserializeFn<Req> {
        let marshaller = &self.req_mar;
        marshaller.de
    }

    /// Returns the `ser` function of the response marshaller.
    #[inline]
    pub fn resp_ser(&self) -> SerializeFn<Resp> {
        let marshaller = &self.resp_mar;
        marshaller.ser
    }

    /// Returns the `de` function of the response marshaller.
    #[inline]
    pub fn resp_de(&self) -> DeserializeFn<Resp> {
        let marshaller = &self.resp_mar;
        marshaller.de
    }
}
/// Status of an RPC: a status code plus an optional details string.
#[derive(Debug, Clone)]
pub struct RpcStatus {
/// Status code of the RPC.
pub status: RpcStatusCode,
/// Optional details text accompanying the status code.
pub details: Option<String>,
}
impl RpcStatus {
    /// Builds a status from a code and optional details.
    pub fn new(status: RpcStatusCode, details: Option<String>) -> RpcStatus {
        Self { status, details }
    }

    /// Builds a status with code `RpcStatusCode::Ok` and no details.
    pub fn ok() -> RpcStatus {
        Self {
            status: RpcStatusCode::Ok,
            details: None,
        }
    }
}
/// Reader over the bytes of a received message.
///
/// Implements `Read` and `BufRead` (and `Buf` when the `prost-codec` feature
/// is enabled) by walking the slices produced by a `GrpcByteBufferReader`.
pub struct MessageReader {
// Never read in this file; presumably held so the byte buffer outlives
// `reader` — TODO confirm.
_buf: GrpcByteBuffer,
// Source of slices; destroyed in `Drop`.
reader: GrpcByteBufferReader,
// Slice currently being consumed.
buffer_slice: GrpcSlice,
// Offset of the next unread byte inside `buffer_slice`.
buffer_offset: usize,
// Number of bytes not yet consumed; decremented by `consume`.
length: usize,
}
impl MessageReader {
    /// Returns the number of bytes that have not been consumed yet.
    #[inline]
    pub fn pending_bytes_count(&self) -> usize {
        let pending = self.length;
        pending
    }
}
// NOTE(review): `MessageReader` owns a `GrpcByteBuffer`, which wraps a raw
// pointer, so `Sync`/`Send` are asserted manually here. Soundness depends on
// the underlying gRPC objects being usable from other threads — confirm.
unsafe impl Sync for MessageReader {}
unsafe impl Send for MessageReader {}
impl Read for MessageReader {
    /// Copies as many bytes as fit into `buf` from the current slice.
    ///
    /// Returns `Ok(0)` when `fill_buf` yields no bytes. A single call never
    /// reads across a slice boundary, so it may return fewer bytes than
    /// `buf.len()` even when more data is pending.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let amt = {
            let bytes = self.fill_buf()?;
            if bytes.is_empty() {
                return Ok(0);
            }
            let amt = cmp::min(buf.len(), bytes.len());
            buf[..amt].copy_from_slice(&bytes[..amt]);
            amt
        };
        self.consume(amt);
        Ok(amt)
    }

    /// Appends all pending bytes to `buf` and returns how many were appended.
    ///
    /// Capacity for every pending byte is reserved up front, then each slice
    /// is appended with `extend_from_slice`. The vector's length therefore
    /// never covers uninitialised memory, and no `unsafe` is needed.
    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        if self.length == 0 {
            return Ok(0);
        }
        buf.reserve(self.length);
        let start = buf.len();
        loop {
            let copied = match self.fill_buf() {
                // No more bytes available: we are done.
                Ok(bytes) if bytes.is_empty() => break,
                Ok(bytes) => {
                    buf.extend_from_slice(bytes);
                    bytes.len()
                }
                // Retry on interruption, as the standard `read_to_end` does.
                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            };
            self.consume(copied);
        }
        Ok(buf.len() - start)
    }
}
impl BufRead for MessageReader {
    /// Returns the unread part of the current slice, fetching the next slice
    /// from the reader first when the current one is empty or fully consumed.
    ///
    /// Returns an empty slice when no bytes are pending.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        if self.length == 0 {
            return Ok(&[]);
        }

        let buffer_len = self.buffer_slice.len();
        let exhausted = buffer_len == 0 || self.buffer_offset == buffer_len;
        if exhausted {
            self.buffer_slice = self.reader.next_slice();
            self.buffer_offset = 0;
        }

        debug_assert!(self.buffer_offset <= buffer_len);
        let offset = self.buffer_offset;
        Ok(self.buffer_slice.range_from(offset))
    }

    /// Marks `amt` bytes as consumed: lowers the pending count and advances
    /// the offset inside the current slice.
    fn consume(&mut self, amt: usize) {
        self.length -= amt;
        self.buffer_offset += amt;
    }
}
impl Drop for MessageReader {
    /// Destroys the underlying reader via `grpc_byte_buffer_reader_destroy`.
    fn drop(&mut self) {
        let reader = &mut self.reader;
        unsafe {
            grpc_sys::grpc_byte_buffer_reader_destroy(reader);
        }
    }
}
#[cfg(feature = "prost-codec")]
impl Buf for MessageReader {
    /// Number of bytes that have not been consumed yet.
    fn remaining(&self) -> usize {
        self.length
    }

    /// Returns the unread part of the current slice.
    ///
    /// Because this takes `&self` it cannot fetch a new slice; it returns an
    /// empty slice when the current slice is empty.
    fn bytes(&self) -> &[u8] {
        if !self.buffer_slice.is_empty() {
            debug_assert!(self.buffer_offset <= self.buffer_slice.len());
            return self.buffer_slice.range_from(self.buffer_offset);
        }
        &[]
    }

    /// Consumes `cnt` bytes, moving on to following slices as needed.
    ///
    /// Whenever the current slice is used up and bytes are still pending, the
    /// next slice is fetched right away. Returns early once nothing is
    /// pending.
    fn advance(&mut self, mut cnt: usize) {
        loop {
            let in_slice = self.buffer_slice.len() - self.buffer_offset;
            if in_slice > cnt {
                break;
            }
            // The rest of the current slice is consumed entirely.
            self.consume(in_slice);
            if self.length == 0 {
                return;
            }
            cnt -= in_slice;
            self.buffer_slice = self.reader.next_slice();
            self.buffer_offset = 0;
        }
        self.consume(cnt);
    }
}
/// Owning handle to a raw `GrpcBatchContext`.
///
/// The pointer is created by `grpcwrap_batch_context_create` and released by
/// `grpcwrap_batch_context_destroy` on drop.
pub struct BatchContext {
// Raw batch-context pointer passed to the `grpcwrap_batch_context_*` calls.
ctx: *mut GrpcBatchContext,
}
impl BatchContext {
    /// Creates a batch context via `grpcwrap_batch_context_create`.
    pub fn new() -> BatchContext {
        BatchContext {
            ctx: unsafe { grpc_sys::grpcwrap_batch_context_create() },
        }
    }

    /// Returns the raw batch-context pointer.
    pub fn as_ptr(&self) -> *mut GrpcBatchContext {
        self.ctx
    }

    /// Takes the received message out of the context.
    ///
    /// Returns `None` when `grpcwrap_batch_context_take_recv_message` yields
    /// a null pointer; otherwise wraps the pointer in an owning
    /// `GrpcByteBuffer`.
    pub fn take_recv_message(&self) -> Option<GrpcByteBuffer> {
        let ptr = unsafe { grpc_sys::grpcwrap_batch_context_take_recv_message(self.ctx) };
        if ptr.is_null() {
            None
        } else {
            Some(GrpcByteBuffer { raw: ptr })
        }
    }

    /// Reads the status received on the client side.
    ///
    /// `details` is `None` when the status code is `Ok`. Otherwise it holds
    /// the details bytes converted with `String::from_utf8_lossy`, or an empty
    /// string when the details pointer is null or the length is zero.
    pub fn rpc_status(&self) -> RpcStatus {
        let status =
            unsafe { grpc_sys::grpcwrap_batch_context_recv_status_on_client_status(self.ctx) };
        let details = if status == RpcStatusCode::Ok {
            None
        } else {
            let mut details_len: usize = 0;
            let details_ptr = unsafe {
                grpc_sys::grpcwrap_batch_context_recv_status_on_client_details(
                    self.ctx,
                    &mut details_len,
                )
            };
            if details_ptr.is_null() || details_len == 0 {
                // `slice::from_raw_parts` requires a non-null pointer even for
                // a zero-length slice, so never hand it a null pointer.
                Some(String::new())
            } else {
                let details_slice =
                    unsafe { slice::from_raw_parts(details_ptr as *const u8, details_len) };
                Some(String::from_utf8_lossy(details_slice).into_owned())
            }
        };
        RpcStatus { status, details }
    }

    /// Takes the received message, if any, and wraps it in a `MessageReader`.
    ///
    /// The reader's pending length is initialised from
    /// `GrpcByteBufferReader::len`, and no slice is fetched up front.
    pub fn recv_message(&mut self) -> Option<MessageReader> {
        let mut buf = self.take_recv_message()?;
        let reader = GrpcByteBufferReader::from(&mut buf);
        let length = reader.len();
        Some(MessageReader {
            _buf: buf,
            reader,
            buffer_slice: Default::default(),
            buffer_offset: 0,
            length,
        })
    }
}
impl Drop for BatchContext {
    /// Releases the context via `grpcwrap_batch_context_destroy`.
    fn drop(&mut self) {
        let ctx = self.ctx;
        unsafe { grpc_sys::grpcwrap_batch_context_destroy(ctx) }
    }
}
/// Boxes `tag` and leaks it, returning the tag's batch-context pointer along
/// with the leaked box as an untyped pointer.
///
/// # Panics
///
/// Panics if `tag.batch_ctx()` returns `None`.
#[inline]
fn box_batch_tag(tag: CallTag) -> (*mut GrpcBatchContext, *mut c_void) {
    let boxed = Box::new(tag);
    // Read the context pointer before ownership of the box is given up.
    let batch_ptr = boxed.batch_ctx().unwrap().as_ptr();
    let tag_ptr = Box::into_raw(boxed) as *mut c_void;
    (batch_ptr, tag_ptr)
}
fn check_run<F>(bt: BatchType, f: F) -> BatchFuture
where
F: FnOnce(*mut GrpcBatchContext, *mut c_void) -> GrpcCallStatus,
{
let (cq_f, tag) = CallTag::batch_pair(bt);
let (batch_ptr, tag_ptr) = box_batch_tag(tag);
let code = f(batch_ptr, tag_ptr);
if code != GrpcCallStatus::Ok {
unsafe {
Box::from_raw(tag_ptr);
}
panic!("create call fail: {:?}", code);
}
cq_f
}
/// A gRPC call together with the completion queue it is associated with.
///
/// The raw call is released with `grpc_call_unref` on drop.
pub struct Call {
/// Raw call pointer.
pub call: *mut GrpcCall,
/// Completion queue borrowed before each operation on the call.
pub cq: CompletionQueue,
}
// NOTE(review): `Call` holds a raw pointer, so `Send` is asserted manually;
// soundness depends on the gRPC call being usable from another thread — confirm.
unsafe impl Send for Call {}
impl Call {
pub unsafe fn from_raw(call: *mut grpc_sys::GrpcCall, cq: CompletionQueue) -> Call {
assert!(!call.is_null());
Call { call, cq }
}
pub fn start_send_message(
&mut self,
msg: &[u8],
write_flags: u32,
initial_meta: bool,
) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
let i = if initial_meta { 1 } else { 0 };
let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_send_message(
self.call,
ctx,
msg.as_ptr() as _,
msg.len(),
write_flags,
i,
tag,
)
});
Ok(f)
}
pub fn start_send_close_client(&mut self) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
let f = check_run(BatchType::Finish, |_, tag| unsafe {
grpc_sys::grpcwrap_call_send_close_from_client(self.call, tag)
});
Ok(f)
}
pub fn start_recv_message(&mut self) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
let f = check_run(BatchType::Read, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_recv_message(self.call, ctx, tag)
});
Ok(f)
}
pub fn start_server_side(&mut self) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_serverside(self.call, ctx, tag)
});
Ok(f)
}
pub fn start_send_status_from_server(
&mut self,
status: &RpcStatus,
send_empty_metadata: bool,
payload: &Option<Vec<u8>>,
write_flags: u32,
) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
let send_empty_metadata = if send_empty_metadata { 1 } else { 0 };
let (payload_ptr, payload_len) = payload
.as_ref()
.map_or((ptr::null(), 0), |b| (b.as_ptr(), b.len()));
let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
let details_ptr = status
.details
.as_ref()
.map_or_else(ptr::null, |s| s.as_ptr() as _);
let details_len = status.details.as_ref().map_or(0, String::len);
grpc_sys::grpcwrap_call_send_status_from_server(
self.call,
ctx,
status.status,
details_ptr,
details_len,
ptr::null_mut(),
send_empty_metadata,
payload_ptr as _,
payload_len,
write_flags,
tag,
)
});
Ok(f)
}
pub fn abort(self, status: &RpcStatus) {
match self.cq.borrow() {
Err(Error::QueueShutdown) => return,
Err(e) => panic!("unexpected error when aborting call: {:?}", e),
_ => {}
}
let call_ptr = self.call;
let tag = CallTag::abort(self);
let (batch_ptr, tag_ptr) = box_batch_tag(tag);
let code = unsafe {
let details_ptr = status
.details
.as_ref()
.map_or_else(ptr::null, |s| s.as_ptr() as _);
let details_len = status.details.as_ref().map_or(0, String::len);
grpc_sys::grpcwrap_call_send_status_from_server(
call_ptr,
batch_ptr,
status.status,
details_ptr,
details_len,
ptr::null_mut(),
1,
ptr::null(),
0,
0,
tag_ptr as *mut c_void,
)
};
if code != GrpcCallStatus::Ok {
unsafe {
Box::from_raw(tag_ptr);
}
panic!("create call fail: {:?}", code);
}
}
fn cancel(&self) {
match self.cq.borrow() {
Err(Error::QueueShutdown) => return,
Err(e) => panic!("unexpected error when canceling call: {:?}", e),
_ => {}
}
unsafe {
grpc_sys::grpc_call_cancel(self.call, ptr::null_mut());
}
}
}
impl Drop for Call {
    /// Releases the raw call via `grpc_call_unref`.
    fn drop(&mut self) {
        let call = self.call;
        unsafe { grpc_sys::grpc_call_unref(call) }
    }
}
/// A call bundled with its close future and the outcome of that future.
struct ShareCall {
// The underlying call.
call: Call,
// Future polled by `poll_finish` to learn when the call has finished.
close_f: BatchFuture,
// Set by `poll_finish` once `close_f` has resolved (successfully or not).
finished: bool,
// Status recorded by `poll_finish`: `RpcStatus::ok()` on success, or the
// failure status when `close_f` fails with `Error::RpcFailure`.
status: Option<RpcStatus>,
}
impl ShareCall {
fn new(call: Call, close_f: BatchFuture) -> ShareCall {
ShareCall {
call,
close_f,
finished: false,
status: None,
}
}
fn poll_finish(&mut self) -> Poll<Option<MessageReader>, Error> {
let res = match self.close_f.poll() {
Err(Error::RpcFailure(status)) => {
self.status = Some(status.clone());
Err(Error::RpcFailure(status))
}
Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(msg)) => {
self.status = Some(RpcStatus::ok());
Ok(Async::Ready(msg))
}
res => res,
};
self.finished = true;
res
}
fn check_alive(&mut self) -> Result<()> {
if self.finished {
return Err(Error::RpcFinished(self.status.clone()));
}
task::check_alive(&self.close_f)
}
}
/// Abstraction over ways of reaching a `ShareCall`: directly, or through an
/// `Arc<SpinLock<ShareCall>>`.
trait ShareCallHolder {
/// Runs `f` with mutable access to the held `ShareCall` and returns its result.
fn call<R, F: FnOnce(&mut ShareCall) -> R>(&mut self, f: F) -> R;
}
impl ShareCallHolder for ShareCall {
    /// Runs `f` directly on this `ShareCall`.
    fn call<R, F: FnOnce(&mut ShareCall) -> R>(&mut self, f: F) -> R {
        let this: &mut ShareCall = self;
        f(this)
    }
}
impl ShareCallHolder for Arc<SpinLock<ShareCall>> {
    /// Locks the spin lock and runs `f` on the guarded `ShareCall`; the guard
    /// is held until `f` returns.
    fn call<R, F: FnOnce(&mut ShareCall) -> R>(&mut self, f: F) -> R {
        let mut guard = self.lock();
        let shared: &mut ShareCall = &mut guard;
        f(shared)
    }
}
/// State shared by the message-receiving side of a streaming call.
struct StreamingBase {
// Optional close future; cleared by `poll` once it resolves successfully.
close_f: Option<BatchFuture>,
// Receive operation currently in flight, if any.
msg_f: Option<BatchFuture>,
// Set once a receive completes without a message.
read_done: bool,
}
impl StreamingBase {
/// Creates the state with the given close future, no receive in flight and
/// reading not yet done.
fn new(close_f: Option<BatchFuture>) -> StreamingBase {
StreamingBase {
close_f,
msg_f: None,
read_done: false,
}
}
/// Polls for the next received message.
///
/// Unless `skip_finish_check` is set, the close future (if any) is polled
/// first: an error is returned immediately, and a successful completion
/// clears `close_f`. Then the in-flight receive (if any) is polled; a
/// completed receive without a message marks reading as done. Once reading is
/// done, `Ready` is returned only after `close_f` has been cleared, otherwise
/// `NotReady`. While reading is not done, a new receive is started on `call`
/// and the message just obtained is returned. If no message was obtained
/// (no receive was in flight), the function recurses once with the finish
/// check skipped to poll the newly started receive.
fn poll<C: ShareCallHolder>(
&mut self,
call: &mut C,
skip_finish_check: bool,
) -> Poll<Option<MessageReader>, Error> {
if !skip_finish_check {
let mut finished = false;
if let Some(ref mut close_f) = self.close_f {
match close_f.poll() {
Ok(Async::Ready(_)) => {
finished = true;
}
Err(e) => return Err(e),
Ok(Async::NotReady) => {}
}
}
// Cleared outside the `if let` so `close_f` is no longer borrowed.
if finished {
self.close_f.take();
}
}
let mut bytes = None;
if !self.read_done {
if let Some(ref mut msg_f) = self.msg_f {
// Returns early with `NotReady` or the error if the receive is not done.
bytes = try_ready!(msg_f.poll());
if bytes.is_none() {
self.read_done = true;
}
}
}
if self.read_done {
if self.close_f.is_none() {
return Ok(Async::Ready(bytes));
}
return Ok(Async::NotReady);
}
// The previous receive (if any) has completed; start the next one.
self.msg_f.take();
let msg_f = call.call(|c| c.call.start_recv_message())?;
self.msg_f = Some(msg_f);
if bytes.is_none() {
// Nothing was received in this round: poll the receive just started.
self.poll(call, true)
} else {
Ok(Async::Ready(bytes))
}
}
/// Cancels the call when reading has not finished or the close future is
/// still present.
fn on_drop<C: ShareCallHolder>(&self, call: &mut C) {
if !self.read_done || self.close_f.is_some() {
call.call(|c| c.call.cancel());
}
}
}
/// Set of write flags stored as a `u32` bit field.
#[derive(Default, Clone, Copy)]
pub struct WriteFlags {
// Raw flag bits, tested against the `grpc_sys::GRPC_WRITE_*` constants.
flags: u32,
}
impl WriteFlags {
    /// Updates the `GRPC_WRITE_BUFFER_HINT` flag through `client::change_flag`
    /// according to `need_buffered` and returns the updated flags.
    pub fn buffer_hint(mut self, need_buffered: bool) -> WriteFlags {
        let bits = &mut self.flags;
        client::change_flag(bits, grpc_sys::GRPC_WRITE_BUFFER_HINT, need_buffered);
        self
    }

    /// Updates the `GRPC_WRITE_NO_COMPRESS` flag through `client::change_flag`
    /// according to `no_compress` and returns the updated flags.
    pub fn force_no_compress(mut self, no_compress: bool) -> WriteFlags {
        let bits = &mut self.flags;
        client::change_flag(bits, grpc_sys::GRPC_WRITE_NO_COMPRESS, no_compress);
        self
    }

    /// Returns whether any `GRPC_WRITE_BUFFER_HINT` bit is set.
    pub fn get_buffer_hint(self) -> bool {
        let masked = self.flags & grpc_sys::GRPC_WRITE_BUFFER_HINT;
        masked != 0
    }

    /// Returns whether any `GRPC_WRITE_NO_COMPRESS` bit is set.
    pub fn get_force_no_compress(self) -> bool {
        let masked = self.flags & grpc_sys::GRPC_WRITE_NO_COMPRESS;
        masked != 0
    }
}
/// State shared by the message-sending side of a call.
struct SinkBase {
// Send operation currently in flight, if any.
batch_f: Option<BatchFuture>,
// Buffer that each outgoing message is serialized into; reused across sends.
buf: Vec<u8>,
// Passed to the send as the initial-metadata flag; cleared after the first
// send has been started.
send_metadata: bool,
}
impl SinkBase {
    /// Creates the state with no send in flight and an empty buffer.
    fn new(send_metadata: bool) -> SinkBase {
        SinkBase {
            send_metadata,
            batch_f: None,
            buf: Vec::new(),
        }
    }

    /// Serializes `t` and starts sending it on `call`.
    ///
    /// Returns `Ok(false)` without sending when a previous send is still in
    /// flight after polling it. The buffer hint is dropped from `flags` for a
    /// send that also carries the metadata. After a send has been started,
    /// `send_metadata` is cleared and `Ok(true)` is returned.
    fn start_send<T, C: ShareCallHolder>(
        &mut self,
        call: &mut C,
        t: &T,
        mut flags: WriteFlags,
        ser: SerializeFn<T>,
    ) -> Result<bool> {
        if self.batch_f.is_some() {
            // `poll_complete` clears `batch_f` exactly when it reports `Ready`,
            // so `NotReady` means the previous send is still outstanding.
            if let Async::NotReady = self.poll_complete()? {
                return Ok(false);
            }
        }

        self.buf.clear();
        ser(t, &mut self.buf);

        let with_metadata = self.send_metadata;
        if with_metadata && flags.get_buffer_hint() {
            flags = flags.buffer_hint(false);
        }

        let payload = &self.buf;
        let raw_flags = flags.flags;
        let write_f =
            call.call(|c| c.call.start_send_message(payload, raw_flags, with_metadata))?;

        self.batch_f = Some(write_f);
        self.send_metadata = false;
        Ok(true)
    }

    /// Polls the in-flight send, if any, and clears it once it has completed.
    ///
    /// Returns early with `NotReady` or the error while the send is unfinished.
    fn poll_complete(&mut self) -> Poll<(), Error> {
        if let Some(pending) = self.batch_f.as_mut() {
            try_ready!(pending.poll());
        }
        self.batch_f = None;
        Ok(Async::Ready(()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MessageReader` over a byte buffer made of `n_slice` copies of
    /// `source`.
    fn make_message_reader(source: &[u8], n_slice: usize) -> MessageReader {
        let mut slices = vec![From::from(source); n_slice];
        let mut buf = GrpcByteBuffer::from(slices.as_mut_slice());
        let reader = GrpcByteBufferReader::from(&mut buf);
        let length = reader.len();
        MessageReader {
            _buf: buf,
            reader,
            buffer_slice: Default::default(),
            buffer_offset: 0,
            length,
        }
    }

    /// Two half-size reads from a single slice must return consecutive halves.
    #[test]
    fn test_typo_len_offset() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
        const HALF_SIZE: usize = 4;
        let mut reader = make_message_reader(&data, 1);
        assert_eq!(reader.pending_bytes_count(), data.len());
        let mut buf = [0; HALF_SIZE];
        // Check the byte counts too, so a short read is reported as such
        // instead of showing up as a data mismatch.
        assert_eq!(HALF_SIZE, reader.read(&mut buf).unwrap());
        assert_eq!(data[..HALF_SIZE], buf);
        assert_eq!(HALF_SIZE, reader.read(&mut buf).unwrap());
        assert_eq!(data[HALF_SIZE..], buf);
        assert_eq!(0, reader.pending_bytes_count());
    }

    /// Exercises `read` and `read_to_end` over various lengths and slice counts.
    #[test]
    fn test_message_reader() {
        for len in 0..1024 + 1 {
            for n_slice in 1..4 {
                let source = vec![len as u8; len];
                let expect = vec![len as u8; len * n_slice];
                // Test read.
                let mut reader = make_message_reader(&source, n_slice);
                let mut dest = [0; 7];
                let amt = reader.read(&mut dest).unwrap();
                assert_eq!(
                    dest[..amt],
                    expect[..amt],
                    "len: {}, nslice: {}",
                    len,
                    n_slice
                );
                // Read after move.
                let mut box_reader = Box::new(reader);
                let amt = box_reader.read(&mut dest).unwrap();
                assert_eq!(
                    dest[..amt],
                    expect[..amt],
                    "len: {}, nslice: {}",
                    len,
                    n_slice
                );
                // Test read_to_end.
                let mut reader = make_message_reader(&source, n_slice);
                let mut dest = vec![];
                reader.read_to_end(&mut dest).unwrap();
                assert_eq!(dest, expect, "len: {}, nslice: {}", len, n_slice);
                assert_eq!(0, reader.pending_bytes_count());
                assert_eq!(0, reader.read(&mut [1]).unwrap())
            }
        }
    }

    /// Exercises the `Buf` implementation with full and partial advances.
    #[cfg(feature = "prost-codec")]
    #[test]
    fn test_buf_impl() {
        for len in 0..1024 + 1 {
            for n_slice in 1..4 {
                let source = vec![len as u8; len];
                let mut reader = make_message_reader(&source, n_slice);
                let mut remaining = len * n_slice;
                // Upper bound on iterations, to catch a reader that stalls.
                let mut count = 100;
                while reader.remaining() > 0 {
                    assert_eq!(remaining, reader.remaining());
                    let bytes = Buf::bytes(&reader);
                    bytes.iter().for_each(|b| assert_eq!(*b, len as u8));
                    let mut read = bytes.len();
                    // For even lengths, leave a tail to exercise partial advances.
                    if read > 5 && len % 2 == 0 {
                        read -= 5;
                    }
                    reader.advance(read);
                    remaining -= read;
                    count -= 1;
                    assert!(count > 0);
                }
                assert_eq!(0, remaining);
                assert_eq!(0, reader.remaining());
            }
        }
    }
}