hadoop_common/ipc/protobuf_rpc_engine2.rs

1use super::{AlignmentContext, Client, ConnectionId, RpcEngine, RpcProtocol, RPC};
2use crate::{
3    conf::Configuration, io::retry::RetryPolicy, ipc::RpcKind, security::UserGroupInformation,
4};
5use anyhow::Error;
6use atomic::Atomic;
7use hadoop_proto::hadoop::common::{
8    rpc_response_header_proto::RpcStatusProto, RequestHeaderProto, RpcResponseHeaderProto,
9};
10use prost::Message;
11use std::{marker::PhantomData, net::SocketAddr, rc::Rc, sync::Arc};
12
/// Stateless [`RpcEngine`] that builds protocol proxies backed by a protobuf
/// [`Invoker`].
#[derive(Debug, Default, Clone, Copy)]
pub struct ProtobufRpcEngine2;
14
15impl RpcEngine for ProtobufRpcEngine2 {
16    fn get_proxy<T: RpcProtocol>(
17        &self,
18        addr: &SocketAddr,
19        ticket: &UserGroupInformation,
20        conf: &Configuration,
21        rpc_timeout: i32,
22        connection_retry_policy: Option<Rc<dyn RetryPolicy>>,
23        fallback_to_simple_auth: Option<Arc<Atomic<bool>>>,
24        alignment_context: Option<Rc<dyn AlignmentContext>>,
25    ) -> anyhow::Result<T> {
26        Ok(T::from(Invoker::from_socket_addr(
27            addr,
28            ticket,
29            conf,
30            rpc_timeout,
31            connection_retry_policy,
32            fallback_to_simple_auth,
33            alignment_context,
34        )?))
35    }
36}
37
38pub struct Invoker<T: RpcProtocol> {
39    remote_id: Rc<ConnectionId>,
40    client: Client,
41    client_protocol_version: u64,
42    protocol_name: String,
43    fallback_to_simple_auth: Option<Arc<Atomic<bool>>>,
44    alignment_context: Option<Rc<dyn AlignmentContext>>,
45    phantom: PhantomData<T>,
46}
47
48impl<T: RpcProtocol> Invoker<T> {
49    pub fn from_socket_addr(
50        addr: &SocketAddr,
51        ticket: &UserGroupInformation,
52        conf: &Configuration,
53        rpc_timeout: i32,
54        connection_retry_policy: Option<Rc<dyn RetryPolicy>>,
55        fallback_to_simple_auth: Option<Arc<Atomic<bool>>>,
56        alignment_context: Option<Rc<dyn AlignmentContext>>,
57    ) -> anyhow::Result<Self> {
58        let connection_id = Rc::new(ConnectionId::get_connection_id(
59            addr,
60            ticket,
61            rpc_timeout,
62            connection_retry_policy,
63            conf,
64        )?);
65        Ok(Self::from_connection_id(
66            connection_id,
67            conf,
68            fallback_to_simple_auth,
69            alignment_context,
70        )?)
71    }
72
73    /// This constructor takes a connection_id, instead of creating a new one.
74    pub fn from_connection_id(
75        conn_id: Rc<ConnectionId>,
76        conf: &Configuration,
77        fallback_to_simple_auth: Option<Arc<Atomic<bool>>>,
78        alignment_context: Option<Rc<dyn AlignmentContext>>,
79    ) -> anyhow::Result<Self> {
80        // TODO: construct & cache client (or consider client singleton)
81
82        // TODO: value_class
83
84        Ok(Self {
85            remote_id: conn_id,
86            client: Client::new("value_class", conf)?,
87            client_protocol_version: RPC::get_protocol_version::<T>(),
88            protocol_name: RPC::get_protocol_name::<T>().to_owned(),
89            fallback_to_simple_auth,
90            alignment_context,
91            phantom: PhantomData,
92        })
93    }
94
95    fn construct_rpc_request_header(&self, method: &str) -> RequestHeaderProto {
96        RequestHeaderProto {
97            method_name: method.to_owned(),
98            declaring_class_protocol_name: self.protocol_name.to_owned(),
99            client_protocol_version: self.client_protocol_version,
100        }
101    }
102
103    /// This is the client side invoker of RPC method.
104    pub fn invoke<M: Default + Message>(
105        &self,
106        method: &str,
107        the_request: &impl Message,
108    ) -> anyhow::Result<M> {
109        let val = self.client.call::<T>(
110            &RpcKind::RpcProtocolBuffer,
111            Rc::new(self.construct_rpc_request(method, the_request)),
112            Rc::clone(&self.remote_id),
113            RPC::RPC_SERVICE_CLASS_DEFAULT,
114            self.fallback_to_simple_auth.as_ref().map(Arc::clone),
115            self.alignment_context.as_ref().map(Rc::clone),
116        )?;
117
118        // TODO: support asynchronous mode
119
120        self.get_return_message(method, &val)
121    }
122
123    fn construct_rpc_request(&self, method: &str, the_request: &impl Message) -> Vec<u8> {
124        let rpc_request_header = self.construct_rpc_request_header(method);
125        let mut output = rpc_request_header.encode_length_delimited_to_vec();
126        let mut payload = the_request.encode_length_delimited_to_vec();
127        output.append(&mut payload);
128        output
129    }
130
131    fn get_return_message<M: Default + Message>(
132        &self,
133        _method: &str,
134        buf: &Vec<u8>,
135    ) -> anyhow::Result<M> {
136        // TODO: use Writable
137
138        let mut buffer = &buf[..];
139        let header: RpcResponseHeaderProto = Message::decode_length_delimited(buffer)?;
140        let status = header.status();
141        if status == RpcStatusProto::Success {
142            let header_len = header.encode_length_delimited_to_vec().len();
143            buffer = &buf[header_len..];
144            let res = M::decode_length_delimited(buffer)?;
145            return Ok(res);
146        }
147        Err(Error::msg(format!("{:#?}", header)))
148    }
149}