1use std::collections::HashMap;
8
9use crate::{
10 message::{self, Receive, Send},
11 Error,
12};
13
/// Identifies a remote RPC by the plugin that exports it and its method name.
///
/// Used as the key of the per-connection binding cache. The built-in
/// `BindMethod` / `RunCommand` entries use an empty plugin name.
///
/// Both fields are `&'static str`, so the identifier is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct Method {
    /// Name of the plugin exporting the method (empty for the built-in entries).
    pub plugin: &'static str,
    /// Name of the method within its plugin.
    pub name: &'static str,
}

impl Method {
    /// Builds a method identifier from its plugin and method names.
    fn new(plugin: &'static str, name: &'static str) -> Self {
        Self { plugin, name }
    }
}
25
26pub struct Channel {
30 stream: std::net::TcpStream,
31 bindings: HashMap<Method, i16>,
32}
33
// --- Handshake ---------------------------------------------------------------

/// Magic string the client sends to open the handshake.
const MAGIC_QUERY: &str = "DFHack?\n";
/// Magic string the server is expected to answer with.
const MAGIC_REPLY: &str = "DFHack!\n";
/// Protocol version sent in the handshake and expected back from the server.
const VERSION: i32 = 1;

// --- Pre-registered method ids -------------------------------------------------

/// Id under which `BindMethod` is registered without being bound first.
const BIND_METHOD_ID: i16 = 0;
/// Id under which `RunCommand` is registered without being bound first.
const RUN_COMMAND_ID: i16 = 1;
40
41impl dfhack_proto::Channel for Channel {
42 type TError = crate::Error;
43
44 fn request<TRequest, TReply>(
45 &mut self,
46 plugin: &'static str,
47 name: &'static str,
48 request: TRequest,
49 ) -> crate::Result<crate::Reply<TReply>>
50 where
51 TRequest: crate::Message,
52 TReply: crate::Message,
53 {
54 let method = Method::new(plugin, name);
55
56 let maybe_id = self.bindings.get(&method);
58
59 let id = match maybe_id {
60 Some(id) => *id,
61 None => {
62 let id = self.bind_method::<TRequest, TReply>(&method)?;
63 self.bindings.insert(method, id);
64 id
65 }
66 };
67
68 self.request_raw(id, request)
69 }
70}
71
72impl Channel {
73 pub(crate) fn connect() -> crate::Result<Self> {
74 let port = match std::env::var("DFHACK_PORT") {
75 Ok(p) => p,
76 Err(_) => "5000".to_string(),
77 };
78 Self::connect_to(&format!("127.0.0.1:{}", port))
79 }
80
81 pub(crate) fn connect_to(address: &str) -> crate::Result<Channel> {
82 log::info!("Connecting to {}", address);
83 let mut client = Channel {
84 stream: std::net::TcpStream::connect(address)?,
85 bindings: HashMap::new(),
86 };
87
88 client
89 .bindings
90 .insert(Method::new("", "BindMethod"), BIND_METHOD_ID);
91 client
92 .bindings
93 .insert(Method::new("", "RunCommand"), RUN_COMMAND_ID);
94
95 let handshake_request = message::Handshake::new(MAGIC_QUERY.to_string(), VERSION);
96 handshake_request.send(&mut client.stream)?;
97 let handshake_reply = message::Handshake::receive(&mut client.stream)?;
98
99 if handshake_reply.magic != MAGIC_REPLY {
100 return Err(Error::ProtocolError(format!(
101 "Unexpected magic {}",
102 handshake_reply.magic
103 )));
104 }
105
106 if handshake_reply.version != VERSION {
107 return Err(Error::ProtocolError(format!(
108 "Unexpected magic version {}",
109 handshake_reply.version
110 )));
111 }
112
113 Ok(client)
114 }
115
116 fn request_raw<TIN: crate::Message, TOUT: crate::Message>(
117 &mut self,
118 id: i16,
119 message: TIN,
120 ) -> crate::Result<crate::Reply<TOUT>> {
121 let request = message::Request::new(id, message);
122 request.send(&mut self.stream)?;
123 let mut fragments = Vec::new();
124
125 loop {
126 let reply: message::Reply<TOUT> = message::Reply::receive(&mut self.stream)?;
127 match reply {
128 message::Reply::Text(text) => {
129 for fragment in &text.fragments {
130 log::info!("{}", fragment.text);
131 }
132 fragments.extend(text.fragments);
133 }
134 message::Reply::Result(result) => {
135 return Ok(crate::Reply {
136 reply: result,
137 fragments,
138 })
139 }
140 message::Reply::Fail(command_result) => {
141 return Err(Error::RpcError {
142 result: command_result,
143 fragments,
144 })
145 }
146 }
147 }
148 }
149
150 fn bind_method<TIN: crate::Message, TOUT: crate::Message>(
151 &mut self,
152 method: &Method,
153 ) -> crate::Result<i16> {
154 let input_msg = TIN::full_name();
155 let output_msg = TOUT::full_name();
156 self.bind_method_by_name(method.plugin, method.name, &input_msg, &output_msg)
157 }
158
159 fn bind_method_by_name(
160 &mut self,
161 plugin: &str,
162 method: &str,
163 input_msg: &str,
164 output_msg: &str,
165 ) -> crate::Result<i16> {
166 log::debug!("Binding the method {}:{}", plugin, method);
167 let request = crate::CoreBindRequest {
168 method: method.to_string(),
169 input_msg: input_msg.to_string(),
170 output_msg: output_msg.to_string(),
171 plugin: Some(plugin.to_string()),
172 };
173 let reply: crate::CoreBindReply = match self.request_raw(BIND_METHOD_ID, request) {
174 Ok(reply) => reply.reply,
175 Err(_) => {
176 log::error!("Error attempting to bind {}", method);
177 return Err(Error::FailedToBind(format!(
178 "{}::{} ({}->{})",
179 plugin, method, input_msg, output_msg,
180 )));
181 }
182 };
183 let id = reply.assigned_id as i16;
184 log::debug!("{}:{} bound to {}", plugin, method, id);
185 Ok(id)
186 }
187}
188
189impl Drop for Channel {
190 fn drop(&mut self) {
191 let quit = message::Quit::new();
192 let res = quit.send(&mut self.stream);
193 if let Err(failure) = res {
194 println!(
195 "Warning: failed to close the connection to dfhack-remote: {}",
196 failure
197 );
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    /// Tests that need a running Dwarf Fortress + DFHack instance to talk to.
    #[cfg(feature = "test-with-df")]
    mod withdf {
        use crate::channel::Channel;
        use crate::Error;

        /// Opens a fresh connection to the live DFHack instance.
        fn live_channel() -> Channel {
            Channel::connect().unwrap()
        }

        /// Binding `GetVersion` with its expected signature succeeds.
        #[test]
        fn bind() {
            let mut conn = live_channel();

            let outcome = conn.bind_method_by_name(
                "",
                "GetVersion",
                "dfproto.EmptyMessage",
                "dfproto.StringMessage",
            );
            outcome.unwrap();
        }

        /// Binding with a mismatched signature, or under the `dorf` plugin name,
        /// is reported as `Error::FailedToBind`.
        #[test]
        fn bad_bind() {
            let mut conn = live_channel();

            let wrong_output = conn
                .bind_method_by_name(
                    "",
                    "GetVersion",
                    "dfproto.EmptyMessage",
                    "dfproto.EmptyMessage",
                )
                .unwrap_err();
            assert!(matches!(wrong_output, Error::FailedToBind(_)));

            let wrong_plugin = conn
                .bind_method_by_name(
                    "dorf",
                    "GetVersion",
                    "dfproto.StringMessage",
                    "dfproto.EmptyMessage",
                )
                .unwrap_err();
            assert!(matches!(wrong_plugin, Error::FailedToBind(_)));
        }

        /// Every method listed by the generated stubs can be bound.
        #[test]
        fn bind_all() {
            use dfhack_proto::{reflection::StubReflection, stubs::Stubs};

            let mut conn = live_channel();
            let catalogue = Stubs::<Channel>::list_methods();

            for stub in &catalogue {
                conn.bind_method_by_name(
                    &stub.plugin_name,
                    &stub.name,
                    &stub.input_type,
                    &stub.output_type,
                )
                .unwrap();
            }
        }
    }
}