dfhack_remote/
channel.rs

//! Internal module describing the exchange flow.
//!
//! Implements the flow of sending and receiving messages with DFHack,
//! including the custom workflow of binding RPC methods before they can be used.
6
7use std::collections::HashMap;
8
9use crate::{
10    message::{self, Receive, Send},
11    Error,
12};
13
/// Identifies an RPC method exposed by DFHack by its plugin and method name.
///
/// An empty plugin name is used for the built-in methods such as
/// `BindMethod` and `RunCommand`. This is the key of the binding cache kept
/// by the communication channel.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct Method {
    pub plugin: String,
    pub name: String,
}

impl Method {
    /// Builds a method identifier from its plugin name and method name.
    fn new(plugin: String, name: String) -> Self {
        Self { plugin, name }
    }
}
25
26/// Communication channel with DFHack.
27///
28/// Stores the existing bindings and keep an open socket.
29pub struct Channel {
30    stream: std::net::TcpStream,
31    bindings: HashMap<Method, i16>,
32}
33
/// Magic header the client sends to open the handshake.
const MAGIC_QUERY: &str = "DFHack?\n";
/// Magic header the server must answer with for the handshake to succeed.
const MAGIC_REPLY: &str = "DFHack!\n";
/// Protocol version sent in, and expected back from, the handshake.
const VERSION: i32 = 1;

/// Fixed id of the core `BindMethod` RPC, registered without being bound.
const BIND_METHOD_ID: i16 = 0;
/// Fixed id of the core `RunCommand` RPC, registered without being bound.
const RUN_COMMAND_ID: i16 = 1;
40
41impl dfhack_proto::Channel for Channel {
42    type TError = crate::Error;
43
44    fn request<TRequest, TReply>(
45        &mut self,
46        plugin: std::string::String,
47        name: std::string::String,
48        request: TRequest,
49    ) -> crate::Result<crate::Reply<TReply>>
50    where
51        TRequest: protobuf::MessageFull,
52        TReply: protobuf::MessageFull,
53    {
54        let method = Method::new(plugin, name);
55
56        // did not manage to use the entry api due to borrow checker
57        let maybe_id = self.bindings.get(&method);
58
59        let id = match maybe_id {
60            Some(id) => *id,
61            None => {
62                let id = self.bind_method::<TRequest, TReply>(&method)?;
63                self.bindings.insert(method, id);
64                id
65            }
66        };
67
68        self.request_raw(id, request)
69    }
70}
71
72impl Channel {
73    pub(crate) fn connect() -> crate::Result<Self> {
74        let port = match std::env::var("DFHACK_PORT") {
75            Ok(p) => p,
76            Err(_) => "5000".to_string(),
77        };
78        Self::connect_to(&format!("127.0.0.1:{}", port))
79    }
80
81    pub(crate) fn connect_to(address: &str) -> crate::Result<Channel> {
82        log::info!("Connecting to {}", address);
83        let mut client = Channel {
84            stream: std::net::TcpStream::connect(address)?,
85            bindings: HashMap::new(),
86        };
87
88        client.bindings.insert(
89            Method::new("".to_string(), "BindMethod".to_string()),
90            BIND_METHOD_ID,
91        );
92        client.bindings.insert(
93            Method::new("".to_string(), "RunCommand".to_string()),
94            RUN_COMMAND_ID,
95        );
96
97        let handshake_request = message::Handshake::new(MAGIC_QUERY.to_string(), VERSION);
98        handshake_request.send(&mut client.stream)?;
99        let handshake_reply = message::Handshake::receive(&mut client.stream)?;
100
101        if handshake_reply.magic != MAGIC_REPLY {
102            return Err(Error::ProtocolError(format!(
103                "Unexpected magic {}",
104                handshake_reply.magic
105            )));
106        }
107
108        if handshake_reply.version != VERSION {
109            return Err(Error::ProtocolError(format!(
110                "Unexpected magic version {}",
111                handshake_reply.version
112            )));
113        }
114
115        Ok(client)
116    }
117
118    fn request_raw<TIN: protobuf::MessageFull, TOUT: protobuf::MessageFull>(
119        &mut self,
120        id: i16,
121        message: TIN,
122    ) -> crate::Result<crate::Reply<TOUT>> {
123        let request = message::Request::new(id, message);
124        request.send(&mut self.stream)?;
125        let mut fragments = Vec::new();
126
127        loop {
128            let reply: message::Reply<TOUT> = message::Reply::receive(&mut self.stream)?;
129            match reply {
130                message::Reply::Text(text) => {
131                    for fragment in &text.fragments {
132                        log::info!("{}", fragment.text());
133                    }
134                    fragments.extend(text.fragments);
135                }
136                message::Reply::Result(result) => {
137                    return Ok(crate::Reply {
138                        reply: result,
139                        fragments,
140                    })
141                }
142                message::Reply::Fail(command_result) => {
143                    return Err(Error::RpcError {
144                        result: command_result,
145                        fragments,
146                    })
147                }
148            }
149        }
150    }
151
152    fn bind_method<TIN: protobuf::MessageFull, TOUT: protobuf::MessageFull>(
153        &mut self,
154        method: &Method,
155    ) -> crate::Result<i16> {
156        let input_descriptor = TIN::descriptor();
157        let output_descriptor = TOUT::descriptor();
158        let input_msg = input_descriptor.full_name();
159        let output_msg = output_descriptor.full_name();
160        self.bind_method_by_name(&method.plugin, &method.name, input_msg, output_msg)
161    }
162
163    fn bind_method_by_name(
164        &mut self,
165        plugin: &str,
166        method: &str,
167        input_msg: &str,
168        output_msg: &str,
169    ) -> crate::Result<i16> {
170        log::debug!("Binding the method {}:{}", plugin, method);
171        let mut request = crate::CoreBindRequest::new();
172        request.set_method(method.to_owned());
173        request.set_input_msg(input_msg.to_string());
174        request.set_output_msg(output_msg.to_string());
175        request.set_plugin(plugin.to_owned());
176        let reply: crate::CoreBindReply = match self.request_raw(BIND_METHOD_ID, request) {
177            Ok(reply) => reply.reply,
178            Err(_) => {
179                log::error!("Error attempting to bind {}", method);
180                return Err(Error::FailedToBind(format!(
181                    "{}::{} ({}->{})",
182                    plugin, method, input_msg, output_msg,
183                )));
184            }
185        };
186        let id = reply.assigned_id() as i16;
187        log::debug!("{}:{} bound to {}", plugin, method, id);
188        Ok(id)
189    }
190}
191
192impl Drop for Channel {
193    fn drop(&mut self) {
194        let quit = message::Quit::new();
195        let res = quit.send(&mut self.stream);
196        if let Err(failure) = res {
197            println!(
198                "Warning: failed to close the connection to dfhack-remote: {}",
199                failure
200            );
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    #[cfg(feature = "test-with-df")]
    mod withdf {
        //! Integration tests that need a DFHack server reachable through
        //! `Channel::connect`.

        use crate::{channel::Channel, Error};

        /// Opens a connection to DFHack, panicking if it cannot be established.
        fn connect() -> Channel {
            Channel::connect().unwrap()
        }

        #[test]
        fn bind() {
            let mut channel = connect();
            let bound = channel.bind_method_by_name(
                "",
                "GetVersion",
                "dfproto.EmptyMessage",
                "dfproto.StringMessage",
            );
            bound.unwrap();
        }

        #[test]
        fn bad_bind() {
            let mut channel = connect();

            // Each (plugin, method, input, output) tuple is an invalid binding.
            let cases = [
                ("", "GetVersion", "dfproto.EmptyMessage", "dfproto.EmptyMessage"),
                ("dorf", "GetVersion", "dfproto.StringMessage", "dfproto.EmptyMessage"),
            ];
            for (plugin, method, input, output) in cases.iter() {
                let err = channel
                    .bind_method_by_name(plugin, method, input, output)
                    .unwrap_err();
                assert!(std::matches!(err, Error::FailedToBind(_)));
            }
        }

        #[test]
        #[cfg(feature = "reflection")]
        fn bind_all() {
            use dfhack_proto::{reflection::StubReflection, stubs::Stubs};

            let mut channel = connect();
            let methods = Stubs::<Channel>::list_methods();

            for method in &methods {
                let bound = channel.bind_method_by_name(
                    &method.plugin_name,
                    &method.name,
                    &method.input_type,
                    &method.output_type,
                );
                bound.unwrap();
            }
        }
    }
}