Skip to main content

dfhack_remote/
channel.rs

1//! Internal module describing the exchange flow
2//!
3//! Implements the flow of sending and receiving messages.
4//! This includes the custom workflow of binding RPC methods
5//! before being able to use them.
6
7use std::collections::HashMap;
8
9use crate::{
10    message::{self, Receive, Send},
11    Error,
12};
13
/// Key identifying a remote RPC method: the plugin exposing it and its name.
///
/// Used as the key of the binding cache, so it must be hashable. Both fields
/// are `&'static str`, which makes the key trivially copyable.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct Method {
    /// Name of the plugin exposing the method (empty for the built-in methods).
    pub plugin: &'static str,
    /// Name of the method within its plugin.
    pub name: &'static str,
}

impl Method {
    /// Builds a method key from a plugin name and a method name.
    fn new(plugin: &'static str, name: &'static str) -> Self {
        Method { plugin, name }
    }
}
25
26/// Communication channel with DFHack.
27///
28/// Stores the existing bindings and keep an open socket.
29pub struct Channel {
30    stream: std::net::TcpStream,
31    bindings: HashMap<Method, i16>,
32}
33
/// Magic string sent by the client to open the handshake.
const MAGIC_QUERY: &str = "DFHack?\n";
/// Magic string the server must answer with for the handshake to succeed.
const MAGIC_REPLY: &str = "DFHack!\n";
/// Protocol version sent in, and expected back from, the handshake.
const VERSION: i32 = 1;

/// Fixed id of the built-in `BindMethod` RPC (registered without binding).
const BIND_METHOD_ID: i16 = 0;
/// Fixed id of the built-in `RunCommand` RPC (registered without binding).
const RUN_COMMAND_ID: i16 = 1;
40
41impl dfhack_proto::Channel for Channel {
42    type TError = crate::Error;
43
44    fn request<TRequest, TReply>(
45        &mut self,
46        plugin: &'static str,
47        name: &'static str,
48        request: TRequest,
49    ) -> crate::Result<crate::Reply<TReply>>
50    where
51        TRequest: crate::Message,
52        TReply: crate::Message,
53    {
54        let method = Method::new(plugin, name);
55
56        // did not manage to use the entry api due to borrow checker
57        let maybe_id = self.bindings.get(&method);
58
59        let id = match maybe_id {
60            Some(id) => *id,
61            None => {
62                let id = self.bind_method::<TRequest, TReply>(&method)?;
63                self.bindings.insert(method, id);
64                id
65            }
66        };
67
68        self.request_raw(id, request)
69    }
70}
71
72impl Channel {
73    pub(crate) fn connect() -> crate::Result<Self> {
74        let port = match std::env::var("DFHACK_PORT") {
75            Ok(p) => p,
76            Err(_) => "5000".to_string(),
77        };
78        Self::connect_to(&format!("127.0.0.1:{}", port))
79    }
80
81    pub(crate) fn connect_to(address: &str) -> crate::Result<Channel> {
82        log::info!("Connecting to {}", address);
83        let mut client = Channel {
84            stream: std::net::TcpStream::connect(address)?,
85            bindings: HashMap::new(),
86        };
87
88        client
89            .bindings
90            .insert(Method::new("", "BindMethod"), BIND_METHOD_ID);
91        client
92            .bindings
93            .insert(Method::new("", "RunCommand"), RUN_COMMAND_ID);
94
95        let handshake_request = message::Handshake::new(MAGIC_QUERY.to_string(), VERSION);
96        handshake_request.send(&mut client.stream)?;
97        let handshake_reply = message::Handshake::receive(&mut client.stream)?;
98
99        if handshake_reply.magic != MAGIC_REPLY {
100            return Err(Error::ProtocolError(format!(
101                "Unexpected magic {}",
102                handshake_reply.magic
103            )));
104        }
105
106        if handshake_reply.version != VERSION {
107            return Err(Error::ProtocolError(format!(
108                "Unexpected magic version {}",
109                handshake_reply.version
110            )));
111        }
112
113        Ok(client)
114    }
115
116    fn request_raw<TIN: crate::Message, TOUT: crate::Message>(
117        &mut self,
118        id: i16,
119        message: TIN,
120    ) -> crate::Result<crate::Reply<TOUT>> {
121        let request = message::Request::new(id, message);
122        request.send(&mut self.stream)?;
123        let mut fragments = Vec::new();
124
125        loop {
126            let reply: message::Reply<TOUT> = message::Reply::receive(&mut self.stream)?;
127            match reply {
128                message::Reply::Text(text) => {
129                    for fragment in &text.fragments {
130                        log::info!("{}", fragment.text);
131                    }
132                    fragments.extend(text.fragments);
133                }
134                message::Reply::Result(result) => {
135                    return Ok(crate::Reply {
136                        reply: result,
137                        fragments,
138                    })
139                }
140                message::Reply::Fail(command_result) => {
141                    return Err(Error::RpcError {
142                        result: command_result,
143                        fragments,
144                    })
145                }
146            }
147        }
148    }
149
150    fn bind_method<TIN: crate::Message, TOUT: crate::Message>(
151        &mut self,
152        method: &Method,
153    ) -> crate::Result<i16> {
154        let input_msg = TIN::full_name();
155        let output_msg = TOUT::full_name();
156        self.bind_method_by_name(method.plugin, method.name, &input_msg, &output_msg)
157    }
158
159    fn bind_method_by_name(
160        &mut self,
161        plugin: &str,
162        method: &str,
163        input_msg: &str,
164        output_msg: &str,
165    ) -> crate::Result<i16> {
166        log::debug!("Binding the method {}:{}", plugin, method);
167        let request = crate::CoreBindRequest {
168            method: method.to_string(),
169            input_msg: input_msg.to_string(),
170            output_msg: output_msg.to_string(),
171            plugin: Some(plugin.to_string()),
172        };
173        let reply: crate::CoreBindReply = match self.request_raw(BIND_METHOD_ID, request) {
174            Ok(reply) => reply.reply,
175            Err(_) => {
176                log::error!("Error attempting to bind {}", method);
177                return Err(Error::FailedToBind(format!(
178                    "{}::{} ({}->{})",
179                    plugin, method, input_msg, output_msg,
180                )));
181            }
182        };
183        let id = reply.assigned_id as i16;
184        log::debug!("{}:{} bound to {}", plugin, method, id);
185        Ok(id)
186    }
187}
188
189impl Drop for Channel {
190    fn drop(&mut self) {
191        let quit = message::Quit::new();
192        let res = quit.send(&mut self.stream);
193        if let Err(failure) = res {
194            println!(
195                "Warning: failed to close the connection to dfhack-remote: {}",
196                failure
197            );
198        }
199    }
200}
201
#[cfg(test)]
mod tests {
    // These tests talk to a live DFHack instance and only build with the
    // `test-with-df` feature.
    #[cfg(feature = "test-with-df")]
    mod withdf {
        use crate::channel::Channel;
        use crate::Error;

        /// Opens a fresh connection to the running DFHack instance.
        fn open() -> Channel {
            Channel::connect().unwrap()
        }

        /// Asserts that binding the given signature fails with `FailedToBind`.
        fn assert_bind_fails(
            channel: &mut Channel,
            plugin: &str,
            method: &str,
            input_msg: &str,
            output_msg: &str,
        ) {
            let err = channel
                .bind_method_by_name(plugin, method, input_msg, output_msg)
                .unwrap_err();
            assert!(matches!(err, Error::FailedToBind(_)));
        }

        #[test]
        fn bind() {
            let mut channel = open();
            channel
                .bind_method_by_name(
                    "",
                    "GetVersion",
                    "dfproto.EmptyMessage",
                    "dfproto.StringMessage",
                )
                .unwrap();
        }

        #[test]
        fn bad_bind() {
            let mut channel = open();

            // Right method, wrong output message type.
            assert_bind_fails(
                &mut channel,
                "",
                "GetVersion",
                "dfproto.EmptyMessage",
                "dfproto.EmptyMessage",
            );

            // `dorf` is not expected to be a real plugin; types are swapped too.
            assert_bind_fails(
                &mut channel,
                "dorf",
                "GetVersion",
                "dfproto.StringMessage",
                "dfproto.EmptyMessage",
            );
        }

        #[test]
        fn bind_all() {
            use dfhack_proto::{reflection::StubReflection, stubs::Stubs};

            let mut channel = open();
            let methods = Stubs::<Channel>::list_methods();
            for method in &methods {
                let bound = channel.bind_method_by_name(
                    &method.plugin_name,
                    &method.name,
                    &method.input_type,
                    &method.output_type,
                );
                bound.unwrap();
            }
        }
    }
}