// nebula_client/v3/graph/transport_response_handler.rs

1use std::io::{Cursor, Error as IoError, ErrorKind as IoErrorKind};
2
3use bytes::BytesMut;
4use fbthrift::{
5    binary_protocol::{BinaryProtocolDeserializer, BinaryProtocolSerializer},
6    ApplicationException, Deserialize, MessageType, ProtocolReader, ProtocolWriter, Serialize,
7};
8use fbthrift_transport_response_handler::ResponseHandler;
9use nebula_fbthrift_graph_v3::services::graph_service::{
10    AuthenticateExn, ExecuteExn, ExecuteJsonExn, SignoutExn,
11};
12
/// Stateless [`ResponseHandler`] for the Nebula Graph v3 `GraphService` transport.
///
/// It fabricates the reply for `GraphService.signout` locally and measures the length of
/// binary-protocol replies for `authenticate`, `execute` and `executeJson`.
#[derive(Clone, Debug, Default)]
pub struct GraphTransportResponseHandler;
15
16impl ResponseHandler for GraphTransportResponseHandler {
17    fn try_make_static_response_bytes(
18        &mut self,
19        _service_name: &'static [u8],
20        fn_name: &'static [u8],
21        request_bytes: &[u8],
22    ) -> Result<Option<Vec<u8>>, IoError> {
23        match fn_name {
24            b"GraphService.authenticate" => Ok(None),
25            b"GraphService.signout" => {
26                let mut des = BinaryProtocolDeserializer::new(Cursor::new(request_bytes));
27                let (name, message_type, seqid) = des
28                    .read_message_begin(|v| v.to_vec())
29                    .map_err(|err| IoError::new(IoErrorKind::Other, err))?;
30
31                if name != b"signout" {
32                    return Err(IoError::new(
33                        IoErrorKind::Other,
34                        format!("Unexpected name {name:?}"),
35                    ));
36                }
37
38                if message_type != MessageType::Call {
39                    return Err(IoError::new(
40                        IoErrorKind::Other,
41                        format!("Unexpected message type {message_type:?}"),
42                    ));
43                }
44
45                let buf = BytesMut::with_capacity(1024);
46                let mut ser = BinaryProtocolSerializer::<BytesMut>::with_buffer(buf);
47
48                ser.write_message_begin("signout", MessageType::Reply, seqid);
49                ser.write_message_end();
50
51                SignoutExn::Success(()).write(&mut ser);
52
53                let res_buf = ser.finish().to_vec();
54
55                Ok(Some(res_buf))
56            }
57            b"GraphService.execute" => Ok(None),
58            b"GraphService.executeJson" => Ok(None),
59            _ => Err(IoError::new(
60                IoErrorKind::Other,
61                format!("Unknown method {}", String::from_utf8_lossy(fn_name)),
62            )),
63        }
64    }
65
66    fn parse_response_bytes(&mut self, response_bytes: &[u8]) -> Result<Option<usize>, IoError> {
67        let mut des = BinaryProtocolDeserializer::new(Cursor::new(response_bytes));
68        let (name, message_type, _) = match des.read_message_begin(|v| v.to_vec()) {
69            Ok(v) => v,
70            Err(_) => return Ok(None),
71        };
72
73        match &name[..] {
74            b"authenticate" => {}
75            b"signout" => unreachable!(),
76            b"execute" => {}
77            b"executeJson" => {}
78            _ => return Ok(None),
79        };
80
81        match message_type {
82            MessageType::Reply => {
83                match &name[..] {
84                    b"authenticate" => {
85                        let _: AuthenticateExn = match Deserialize::read(&mut des) {
86                            Ok(v) => v,
87                            Err(_) => return Ok(None),
88                        };
89                    }
90                    b"execute" => {
91                        let _: ExecuteExn = match Deserialize::read(&mut des) {
92                            Ok(v) => v,
93                            Err(_) => return Ok(None),
94                        };
95                    }
96                    b"executeJson" => {
97                        let _: ExecuteJsonExn = match Deserialize::read(&mut des) {
98                            Ok(v) => v,
99                            Err(_) => return Ok(None),
100                        };
101                    }
102                    _ => unreachable!(),
103                };
104            }
105            MessageType::Exception => {
106                let _: ApplicationException = match Deserialize::read(&mut des) {
107                    Ok(v) => v,
108                    Err(_) => return Ok(None),
109                };
110            }
111            MessageType::Call | MessageType::Oneway | MessageType::InvalidMessageType => {}
112        }
113
114        match des.read_message_end() {
115            Ok(v) => v,
116            Err(_) => return Ok(None),
117        };
118
119        Ok(Some(des.into_inner().position() as usize))
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_try_make_static_response_bytes() -> Result<(), Box<dyn std::error::Error>> {
        let mut handler = GraphTransportResponseHandler;

        assert_eq!(
            handler.try_make_static_response_bytes(
                b"GraphService",
                b"GraphService.authenticate",
                b"FOO"
            )?,
            None
        );
        assert_eq!(
            handler.try_make_static_response_bytes(
                b"GraphService",
                b"GraphService.execute",
                b"FOO"
            )?,
            None
        );
        assert_eq!(
            handler.try_make_static_response_bytes(
                b"GraphService",
                b"GraphService.executeJson",
                b"FOO"
            )?,
            None
        );
        match handler.try_make_static_response_bytes(b"GraphService", b"GraphService.foo", b"FOO") {
            Ok(_) => panic!(),
            Err(err) => {
                assert_eq!(err.kind(), IoErrorKind::Other);

                assert_eq!(err.to_string(), "Unknown method GraphService.foo");
            }
        }

        Ok(())
    }

    #[test]
    fn test_try_make_static_response_bytes_with_signout() -> Result<(), Box<dyn std::error::Error>>
    {
        let mut handler = GraphTransportResponseHandler;

        //
        // Ref https://github.com/bk-rs/nebula-rs/blob/e500e6f93b0ffcd009038c2a51b41a6aa3488b18/nebula-fbthrift/nebula-fbthrift-graph-v2/src/lib.rs#L1346
        //
        let request = ::fbthrift::serialize!(::fbthrift::BinaryProtocol, |p| {
            p.write_message_begin("signout", ::fbthrift::MessageType::Call, 7);

            p.write_struct_begin("args");
            p.write_field_begin("arg_sessionId", ::fbthrift::TType::I64, 1i16);
            ::fbthrift::Serialize::write(&1, p);
            p.write_field_end();
            p.write_field_stop();
            p.write_struct_end();

            p.write_message_end();
        });

        let response = handler
            .try_make_static_response_bytes(b"GraphService", b"GraphService.signout", &request[..])?
            .expect("signout must produce a static response");

        // The fabricated reply must be a `signout` Reply that echoes the request's seqid.
        let mut des = BinaryProtocolDeserializer::new(Cursor::new(&response[..]));
        let (name, message_type, seqid) = des
            .read_message_begin(|v| v.to_vec())
            .map_err(|err| IoError::new(IoErrorKind::Other, err))?;
        assert_eq!(name, b"signout".to_vec());
        assert_eq!(message_type, MessageType::Reply);
        assert_eq!(seqid, 7);

        Ok(())
    }

    #[test]
    fn test_try_make_static_response_bytes_with_invalid_signout_request() {
        let mut handler = GraphTransportResponseHandler;

        // Too short to hold a message header.
        let err = handler
            .try_make_static_response_bytes(b"GraphService", b"GraphService.signout", b"FOO")
            .expect_err("an undecodable request must be rejected");
        assert_eq!(err.kind(), IoErrorKind::Other);

        // Header names a different method.
        let request = ::fbthrift::serialize!(::fbthrift::BinaryProtocol, |p| {
            p.write_message_begin("execute", ::fbthrift::MessageType::Call, 0);
            p.write_message_end();
        });
        let err = handler
            .try_make_static_response_bytes(b"GraphService", b"GraphService.signout", &request[..])
            .expect_err("a request for another method must be rejected");
        assert_eq!(err.kind(), IoErrorKind::Other);

        // Header is not a Call.
        let request = ::fbthrift::serialize!(::fbthrift::BinaryProtocol, |p| {
            p.write_message_begin("signout", ::fbthrift::MessageType::Reply, 0);
            p.write_message_end();
        });
        let err = handler
            .try_make_static_response_bytes(b"GraphService", b"GraphService.signout", &request[..])
            .expect_err("a non-Call request must be rejected");
        assert_eq!(err.kind(), IoErrorKind::Other);
    }

    #[test]
    fn test_parse_response_bytes_without_complete_known_message(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut handler = GraphTransportResponseHandler;

        // Not enough bytes for a message header.
        assert_eq!(handler.parse_response_bytes(b"")?, None);
        assert_eq!(handler.parse_response_bytes(b"FOO")?, None);

        // A decodable header for a method this handler does not know.
        let response = ::fbthrift::serialize!(::fbthrift::BinaryProtocol, |p| {
            p.write_message_begin("foo", ::fbthrift::MessageType::Reply, 0);
            p.write_message_end();
        });
        assert_eq!(handler.parse_response_bytes(&response[..])?, None);

        // A known Reply header whose body has not arrived yet.
        let response = ::fbthrift::serialize!(::fbthrift::BinaryProtocol, |p| {
            p.write_message_begin("authenticate", ::fbthrift::MessageType::Reply, 0);
            p.write_message_end();
        });
        assert_eq!(handler.parse_response_bytes(&response[..])?, None);

        Ok(())
    }
}