1use crate::register::RegisterBuilder;
2use crate::route::Route;
3use crate::support::triple::{TripleExceptionWrapper, TripleRequestWrapper, TripleResponseWrapper};
4use bytes::{BufMut, BytesMut};
5use http::{HeaderValue, Request};
6use http_body_util::{BodyExt, Full};
7use hyper::client::conn::http2::SendRequest;
8use krpc_common::{KrpcMsg, RpcError};
9use prost::Message;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15pub struct KrpcClient {
16 route: Route,
17}
18
19impl KrpcClient {
20 pub fn build(register_builder: RegisterBuilder) -> KrpcClient {
21 let map = Arc::new(RwLock::new(HashMap::new()));
22 let register = register_builder.init(map.clone());
23 let cli = KrpcClient {
24 route: Route::new(map, register),
25 };
26 return cli;
27 }
28
29 pub async fn invoke<Res>(&self, msg: KrpcMsg) -> Result<Res, RpcError>
30 where
31 Res: Send + Sync + Serialize + for<'a> Deserialize<'a> + Default,
32 {
33 let mut sender: SendRequest<Full<bytes::Bytes>> = self
34 .route
35 .get_socket_sender(&msg.class_name, msg.version.as_deref())
36 .await
37 .map_err(|e| RpcError::Client(e.to_string()))?;
38 let buf = TripleRequestWrapper::get_buf(msg.req);
39 let mut builder = Request::builder()
40 .uri("/".to_owned() + &msg.class_name + "/" + &msg.method_name)
41 .header("content-type", "application/grpc+proto");
42 if let Some(version) = msg.version {
43 builder.headers_mut().unwrap().insert(
44 "tri-service-version",
45 HeaderValue::from_str(&version).unwrap(),
46 );
47 }
48 let req = builder
49 .body(Full::<bytes::Bytes>::from(buf))
50 .map_err(|e| RpcError::Client(e.to_string()))?;
51 let mut response = sender
52 .send_request(req)
53 .await
54 .map_err(|e| RpcError::Client(e.to_string()))?;
55 let mut res_body = BytesMut::new();
56 loop {
57 let res_frame = response
58 .frame()
59 .await
60 .map_or(Err(RpcError::Server("error frame 1".to_owned())), |e| Ok(e))?
61 .map_err(|e| RpcError::Client(e.to_string()))?;
62 if res_frame.is_trailers() {
63 let trailers = res_frame
64 .trailers_ref()
65 .map_or(Err(RpcError::Server("error frame 2".to_owned())), |e| Ok(e))?;
66 match trailers.get("grpc-status") {
67 Some(status) => match status.as_bytes() {
68 b"0" => {
69 let trip_res = TripleResponseWrapper::decode(&res_body.to_vec()[5..])
70 .map_err(|e| RpcError::Client(e.to_string()))?;
71 if trip_res.is_empty_body() {
72 return Err(RpcError::Null);
73 }
74 let res: Res = serde_json::from_slice(&trip_res.data)
75 .map_err(|e| RpcError::Client(e.to_string()))?;
76 return Ok(res);
77 }
78 else_status => {
79 if !res_body.is_empty() {
80 let trip_res: TripleExceptionWrapper =
81 TripleExceptionWrapper::decode(&res_body.to_vec()[5..])
82 .map_err(|e| RpcError::Client(e.to_string()))?;
83 let msg = String::from_utf8(trip_res.data).unwrap();
84 match else_status {
85 b"90" => return Err(RpcError::Client(msg)),
86 b"91" => return Err(RpcError::Method(msg)),
87 b"92" => return Err(RpcError::Null),
88 _ => return Err(RpcError::Server(msg)),
89 }
90 }
91 return Err(RpcError::Server(match trailers.get("grpc-message") {
92 Some(value) => {
93 "grpc-message=".to_owned()
94 + &String::from_utf8(value.as_bytes().to_vec()).unwrap()
95 }
96 None => {
97 "grpc-status=".to_owned()
98 + &String::from_utf8(else_status.to_vec()).unwrap()
99 }
100 }));
101 }
102 },
103 None => return Err(RpcError::Server("error frame 3".to_owned())),
104 }
105 } else {
106 let res_data = res_frame
107 .into_data()
108 .map_err(|_e| RpcError::Server("error frame 4".to_owned()))?;
109 let _ = res_body.put(res_data);
110 }
111 }
112 }
113}