1use super::{AlignmentContext, Client, ConnectionId, RpcEngine, RpcProtocol, RPC};
2use crate::{
3 conf::Configuration, io::retry::RetryPolicy, ipc::RpcKind, security::UserGroupInformation,
4};
5use anyhow::Error;
6use atomic::Atomic;
7use hadoop_proto::hadoop::common::{
8 rpc_response_header_proto::RpcStatusProto, RequestHeaderProto, RpcResponseHeaderProto,
9};
10use prost::Message;
11use std::{marker::PhantomData, net::SocketAddr, rc::Rc, sync::Arc};
12
/// Protobuf-based RPC engine.
///
/// Its `get_proxy` implementation builds an `Invoker` for the target address
/// and converts it into the requested protocol proxy type.
#[derive(Debug, Clone, Copy, Default)]
pub struct ProtobufRpcEngine2;
14
15impl RpcEngine for ProtobufRpcEngine2 {
16 fn get_proxy<T: RpcProtocol>(
17 &self,
18 addr: &SocketAddr,
19 ticket: &UserGroupInformation,
20 conf: &Configuration,
21 rpc_timeout: i32,
22 connection_retry_policy: Option<Rc<dyn RetryPolicy>>,
23 fallback_to_simple_auth: Option<Arc<Atomic<bool>>>,
24 alignment_context: Option<Rc<dyn AlignmentContext>>,
25 ) -> anyhow::Result<T> {
26 Ok(T::from(Invoker::from_socket_addr(
27 addr,
28 ticket,
29 conf,
30 rpc_timeout,
31 connection_retry_policy,
32 fallback_to_simple_auth,
33 alignment_context,
34 )?))
35 }
36}
37
38pub struct Invoker<T: RpcProtocol> {
39 remote_id: Rc<ConnectionId>,
40 client: Client,
41 client_protocol_version: u64,
42 protocol_name: String,
43 fallback_to_simple_auth: Option<Arc<Atomic<bool>>>,
44 alignment_context: Option<Rc<dyn AlignmentContext>>,
45 phantom: PhantomData<T>,
46}
47
48impl<T: RpcProtocol> Invoker<T> {
49 pub fn from_socket_addr(
50 addr: &SocketAddr,
51 ticket: &UserGroupInformation,
52 conf: &Configuration,
53 rpc_timeout: i32,
54 connection_retry_policy: Option<Rc<dyn RetryPolicy>>,
55 fallback_to_simple_auth: Option<Arc<Atomic<bool>>>,
56 alignment_context: Option<Rc<dyn AlignmentContext>>,
57 ) -> anyhow::Result<Self> {
58 let connection_id = Rc::new(ConnectionId::get_connection_id(
59 addr,
60 ticket,
61 rpc_timeout,
62 connection_retry_policy,
63 conf,
64 )?);
65 Ok(Self::from_connection_id(
66 connection_id,
67 conf,
68 fallback_to_simple_auth,
69 alignment_context,
70 )?)
71 }
72
73 pub fn from_connection_id(
75 conn_id: Rc<ConnectionId>,
76 conf: &Configuration,
77 fallback_to_simple_auth: Option<Arc<Atomic<bool>>>,
78 alignment_context: Option<Rc<dyn AlignmentContext>>,
79 ) -> anyhow::Result<Self> {
80 Ok(Self {
85 remote_id: conn_id,
86 client: Client::new("value_class", conf)?,
87 client_protocol_version: RPC::get_protocol_version::<T>(),
88 protocol_name: RPC::get_protocol_name::<T>().to_owned(),
89 fallback_to_simple_auth,
90 alignment_context,
91 phantom: PhantomData,
92 })
93 }
94
95 fn construct_rpc_request_header(&self, method: &str) -> RequestHeaderProto {
96 RequestHeaderProto {
97 method_name: method.to_owned(),
98 declaring_class_protocol_name: self.protocol_name.to_owned(),
99 client_protocol_version: self.client_protocol_version,
100 }
101 }
102
103 pub fn invoke<M: Default + Message>(
105 &self,
106 method: &str,
107 the_request: &impl Message,
108 ) -> anyhow::Result<M> {
109 let val = self.client.call::<T>(
110 &RpcKind::RpcProtocolBuffer,
111 Rc::new(self.construct_rpc_request(method, the_request)),
112 Rc::clone(&self.remote_id),
113 RPC::RPC_SERVICE_CLASS_DEFAULT,
114 self.fallback_to_simple_auth.as_ref().map(Arc::clone),
115 self.alignment_context.as_ref().map(Rc::clone),
116 )?;
117
118 self.get_return_message(method, &val)
121 }
122
123 fn construct_rpc_request(&self, method: &str, the_request: &impl Message) -> Vec<u8> {
124 let rpc_request_header = self.construct_rpc_request_header(method);
125 let mut output = rpc_request_header.encode_length_delimited_to_vec();
126 let mut payload = the_request.encode_length_delimited_to_vec();
127 output.append(&mut payload);
128 output
129 }
130
131 fn get_return_message<M: Default + Message>(
132 &self,
133 _method: &str,
134 buf: &Vec<u8>,
135 ) -> anyhow::Result<M> {
136 let mut buffer = &buf[..];
139 let header: RpcResponseHeaderProto = Message::decode_length_delimited(buffer)?;
140 let status = header.status();
141 if status == RpcStatusProto::Success {
142 let header_len = header.encode_length_delimited_to_vec().len();
143 buffer = &buf[header_len..];
144 let res = M::decode_length_delimited(buffer)?;
145 return Ok(res);
146 }
147 Err(Error::msg(format!("{:#?}", header)))
148 }
149}