1use std::collections::HashMap;
8
9use crate::{
10 message::{self, Receive, Send},
11 Error,
12};
13
/// Identifies a remote method by the plugin that exposes it and its name.
///
/// Used as the key of the channel's binding cache, hence the `Eq`/`Hash`
/// derives. `Debug` is derived so the key can be shown in diagnostics.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct Method {
    /// Name of the plugin exposing the method.
    pub plugin: String,
    /// Name of the method inside the plugin.
    pub name: String,
}

impl Method {
    /// Builds a method key from an owned plugin name and method name.
    fn new(plugin: String, name: String) -> Self {
        Self { plugin, name }
    }
}
25
/// Channel exchanging messages over a TCP stream.
///
/// Keeps a cache of the numeric ids that methods have been bound to, so
/// each method only needs to be bound once per channel.
pub struct Channel {
    /// TCP connection that every message is sent to and received from.
    stream: std::net::TcpStream,
    /// Cache mapping a (plugin, method name) pair to its bound id.
    bindings: HashMap<Method, i16>,
}
33
// Magic string sent in the handshake request when connecting.
const MAGIC_QUERY: &str = "DFHack?\n";
// Magic string the handshake reply must carry for the connection to be accepted.
const MAGIC_REPLY: &str = "DFHack!\n";
// Protocol version sent in the handshake and required back in the reply.
const VERSION: i32 = 1;

// Id pre-registered for the "BindMethod" method; used to send bind requests.
const BIND_METHOD_ID: i16 = 0;
// Id pre-registered for the "RunCommand" method.
const RUN_COMMAND_ID: i16 = 1;
40
41impl dfhack_proto::Channel for Channel {
42 type TError = crate::Error;
43
44 fn request<TRequest, TReply>(
45 &mut self,
46 plugin: std::string::String,
47 name: std::string::String,
48 request: TRequest,
49 ) -> crate::Result<crate::Reply<TReply>>
50 where
51 TRequest: protobuf::MessageFull,
52 TReply: protobuf::MessageFull,
53 {
54 let method = Method::new(plugin, name);
55
56 let maybe_id = self.bindings.get(&method);
58
59 let id = match maybe_id {
60 Some(id) => *id,
61 None => {
62 let id = self.bind_method::<TRequest, TReply>(&method)?;
63 self.bindings.insert(method, id);
64 id
65 }
66 };
67
68 self.request_raw(id, request)
69 }
70}
71
72impl Channel {
73 pub(crate) fn connect() -> crate::Result<Self> {
74 let port = match std::env::var("DFHACK_PORT") {
75 Ok(p) => p,
76 Err(_) => "5000".to_string(),
77 };
78 Self::connect_to(&format!("127.0.0.1:{}", port))
79 }
80
81 pub(crate) fn connect_to(address: &str) -> crate::Result<Channel> {
82 log::info!("Connecting to {}", address);
83 let mut client = Channel {
84 stream: std::net::TcpStream::connect(address)?,
85 bindings: HashMap::new(),
86 };
87
88 client.bindings.insert(
89 Method::new("".to_string(), "BindMethod".to_string()),
90 BIND_METHOD_ID,
91 );
92 client.bindings.insert(
93 Method::new("".to_string(), "RunCommand".to_string()),
94 RUN_COMMAND_ID,
95 );
96
97 let handshake_request = message::Handshake::new(MAGIC_QUERY.to_string(), VERSION);
98 handshake_request.send(&mut client.stream)?;
99 let handshake_reply = message::Handshake::receive(&mut client.stream)?;
100
101 if handshake_reply.magic != MAGIC_REPLY {
102 return Err(Error::ProtocolError(format!(
103 "Unexpected magic {}",
104 handshake_reply.magic
105 )));
106 }
107
108 if handshake_reply.version != VERSION {
109 return Err(Error::ProtocolError(format!(
110 "Unexpected magic version {}",
111 handshake_reply.version
112 )));
113 }
114
115 Ok(client)
116 }
117
118 fn request_raw<TIN: protobuf::MessageFull, TOUT: protobuf::MessageFull>(
119 &mut self,
120 id: i16,
121 message: TIN,
122 ) -> crate::Result<crate::Reply<TOUT>> {
123 let request = message::Request::new(id, message);
124 request.send(&mut self.stream)?;
125 let mut fragments = Vec::new();
126
127 loop {
128 let reply: message::Reply<TOUT> = message::Reply::receive(&mut self.stream)?;
129 match reply {
130 message::Reply::Text(text) => {
131 for fragment in &text.fragments {
132 log::info!("{}", fragment.text());
133 }
134 fragments.extend(text.fragments);
135 }
136 message::Reply::Result(result) => {
137 return Ok(crate::Reply {
138 reply: result,
139 fragments,
140 })
141 }
142 message::Reply::Fail(command_result) => {
143 return Err(Error::RpcError {
144 result: command_result,
145 fragments,
146 })
147 }
148 }
149 }
150 }
151
152 fn bind_method<TIN: protobuf::MessageFull, TOUT: protobuf::MessageFull>(
153 &mut self,
154 method: &Method,
155 ) -> crate::Result<i16> {
156 let input_descriptor = TIN::descriptor();
157 let output_descriptor = TOUT::descriptor();
158 let input_msg = input_descriptor.full_name();
159 let output_msg = output_descriptor.full_name();
160 self.bind_method_by_name(&method.plugin, &method.name, input_msg, output_msg)
161 }
162
163 fn bind_method_by_name(
164 &mut self,
165 plugin: &str,
166 method: &str,
167 input_msg: &str,
168 output_msg: &str,
169 ) -> crate::Result<i16> {
170 log::debug!("Binding the method {}:{}", plugin, method);
171 let mut request = crate::CoreBindRequest::new();
172 request.set_method(method.to_owned());
173 request.set_input_msg(input_msg.to_string());
174 request.set_output_msg(output_msg.to_string());
175 request.set_plugin(plugin.to_owned());
176 let reply: crate::CoreBindReply = match self.request_raw(BIND_METHOD_ID, request) {
177 Ok(reply) => reply.reply,
178 Err(_) => {
179 log::error!("Error attempting to bind {}", method);
180 return Err(Error::FailedToBind(format!(
181 "{}::{} ({}->{})",
182 plugin, method, input_msg, output_msg,
183 )));
184 }
185 };
186 let id = reply.assigned_id() as i16;
187 log::debug!("{}:{} bound to {}", plugin, method, id);
188 Ok(id)
189 }
190}
191
192impl Drop for Channel {
193 fn drop(&mut self) {
194 let quit = message::Quit::new();
195 let res = quit.send(&mut self.stream);
196 if let Err(failure) = res {
197 println!(
198 "Warning: failed to close the connection to dfhack-remote: {}",
199 failure
200 );
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    // These tests need a live connection and are therefore gated behind
    // the `test-with-df` feature.
    #[cfg(feature = "test-with-df")]
    mod withdf {
        use crate::channel::Channel;
        use crate::Error;

        /// Binding a method with matching message types succeeds.
        #[test]
        fn bind() {
            let mut client = Channel::connect().unwrap();

            let bound = client.bind_method_by_name(
                "",
                "GetVersion",
                "dfproto.EmptyMessage",
                "dfproto.StringMessage",
            );
            bound.unwrap();
        }

        /// Binding with a wrong signature or plugin is reported as `FailedToBind`.
        #[test]
        fn bad_bind() {
            let mut client = Channel::connect().unwrap();

            // (plugin, method, input message, output message)
            let rejected = [
                (
                    "",
                    "GetVersion",
                    "dfproto.EmptyMessage",
                    "dfproto.EmptyMessage",
                ),
                (
                    "dorf",
                    "GetVersion",
                    "dfproto.StringMessage",
                    "dfproto.EmptyMessage",
                ),
            ];

            for &(plugin, method, input, output) in rejected.iter() {
                let failure = client
                    .bind_method_by_name(plugin, method, input, output)
                    .unwrap_err();
                assert!(matches!(failure, Error::FailedToBind(_)));
            }
        }

        /// Every method listed by the generated stubs can be bound.
        #[test]
        #[cfg(feature = "reflection")]
        fn bind_all() {
            use dfhack_proto::{reflection::StubReflection, stubs::Stubs};

            let mut client = Channel::connect().unwrap();
            let methods = Stubs::<Channel>::list_methods();

            for entry in &methods {
                let bound = client.bind_method_by_name(
                    &entry.plugin_name,
                    &entry.name,
                    &entry.input_type,
                    &entry.output_type,
                );
                bound.unwrap();
            }
        }
    }
}