// pad_motion/server.rs

use std::collections::{HashMap, HashSet};
use std::io::{Error, ErrorKind, Result};
use std::net::{SocketAddr, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use std::time::Duration;

use rand::Rng;

use crate::protocol::*;

14#[derive(Copy, Clone, Debug, Default)]
15struct Slot {
16  controller_info: ControllerInfo,
17  controller_data: ControllerData
18}
19
/// Per-client subscription state: which controllers a client asked to be
/// informed about and how many data packets have been sent to it so far.
#[derive(Debug, Default)]
struct RequestedControllerData {
  /// Number stamped on the next controller-data packet sent to this client.
  packet_number: u32,
  /// Slots the client subscribed to by slot number.
  slot_numbers: HashSet<u8>,
  /// Controllers the client subscribed to by MAC address.
  mac_addresses: HashSet<u64>
}

/// UDP port the server binds to on `127.0.0.1` when no explicit address is given.
const DEFAULT_PORT: u16 = 26760;

28pub trait DsServer {
29  /// Starts background server thread.
30  fn start(self, countinue_running: Arc<AtomicBool>) -> JoinHandle<()>;
31
32  /// Update controller info (it will automatically send this data to connected clients).
33  fn update_controller_info(&self, controller_info: ControllerInfo);
34
35  /// Update controller data (it will automatically send this data to connected clients).
36  fn update_controller_data(&self, slot_number: u8, controller_data: ControllerData);
37}
38
39pub struct Server {
40  message_header: MessageHeader,
41  slots: Mutex<[Slot; 4]>,
42  connected_clients: Mutex<HashMap<SocketAddr, RequestedControllerData>>,
43  socket: UdpSocket,
44}
45
46impl Server {
47  /// Creates new server.
48  /// 
49  /// # Arguments
50  /// 
51  /// * `id` - server ID, pass `None` to use a random number.
52  /// * `address` - server's UDP socket address, if `None` is passed `127.0.0.1:26760` is used.
53  pub fn new(id: Option<u32>, address: Option<SocketAddr>) -> Result<Server> {
54    let mut rng = rand::thread_rng();
55
56    let server_id = match id {
57      Some(id) => id,
58      None => rng.gen()
59    };
60
61    let message_header = {
62      MessageHeader {
63        source: MessageSource::Server,
64        protocol_version: PROTOCOL_VERSION,
65        message_length: 0,
66        checksum: 0,
67        source_id: server_id
68      }
69    };
70
71    let slots = {
72      let mut slots: [Slot; 4] = [Default::default(); 4];
73      let mut i = 0;
74      for slot in slots.iter_mut() {
75        slot.controller_info.slot = i;
76        i += 1;
77      }
78
79      Mutex::new(slots)
80    };
81
82    let connected_clients = Mutex::new(HashMap::new());
83
84    let socket_address = match address {
85      Some(address) => address,
86      None => SocketAddr::from(([127, 0, 0, 1], DEFAULT_PORT))
87    };
88    let socket = UdpSocket::bind(socket_address)?;
89    socket.set_read_timeout(Some(Duration::from_secs_f64(0.2)))?;
90    socket.set_write_timeout(Some(Duration::from_secs_f64(0.2)))?;
91
92    Ok(Server {
93      message_header,
94      slots,
95      connected_clients,
96      socket
97    })
98  }
99
100  fn encode_and_send(&self, target: SocketAddr, message: Message) -> Result<()> {
101    let mut encoded_message = vec![];
102    encode_message(&mut encoded_message, message).unwrap();
103  
104    self.socket.send_to(&encoded_message, target).map(|_amount| ())
105  }
106
107  fn send_protocol_version(&self, target: SocketAddr) -> Result<()> {
108    let message = Message {
109      header: self.message_header,
110      message_type: MessageType::ConnectedControllers,
111      payload: MessagePayload::ProtocolVersion(PROTOCOL_VERSION)
112    };
113
114    self.encode_and_send(target, message)
115  }
116
117  fn send_connected_controller_info(&self, target: SocketAddr, slot_number: u8) -> Result<()> {
118    let controller_info = self.slots.lock().unwrap()[slot_number as usize].controller_info;
119
120    let payload = MessagePayload::ConnectedControllerResponse {
121      controller_info
122    };
123
124    let message = Message {
125      header: self.message_header,
126      message_type: MessageType::ConnectedControllers,
127      payload
128    };
129
130    self.encode_and_send(target, message)
131  }
132
133  fn send_slot_data(&self, target: SocketAddr, 
134                    slot: Slot, packet_number: &mut u32) -> Result<()> {
135    let payload = MessagePayload::ControllerData {
136      packet_number: *packet_number,
137      controller_info: slot.controller_info,
138      controller_data: slot.controller_data  
139    };
140
141    let message = Message {
142      header: self.message_header,
143      message_type: MessageType::ControllerData,
144      payload
145    };
146
147    let result = self.encode_and_send(target, message);
148    if result.is_ok() {
149      *packet_number += 1;
150    }
151
152    result
153  }
154
155  fn send_controller_data(&self) -> Result<()> {
156    let slots = self.slots.lock().unwrap();
157    let mut connected_clients = self.connected_clients.lock().unwrap();
158
159    connected_clients.retain(|&client_address, requested_controller_data| {
160      let mut already_sent = HashSet::new();
161
162      for &slot_number in requested_controller_data.slot_numbers.iter() {
163        let slot = slots[slot_number as usize];
164        let result = self.send_slot_data(client_address, 
165                                         slot, 
166                                         &mut requested_controller_data.packet_number);
167
168        if result.is_ok() {
169          already_sent.insert(slot_number);
170        } else {
171          return false;
172        }
173      }
174
175      for &mac_address in requested_controller_data.mac_addresses.iter() {
176        let slot_number = slots.iter().position(|slot| slot.controller_info.mac_address == mac_address);
177        if let Some(slot_number) = slot_number {
178          if !already_sent.contains(&(slot_number as u8)) {
179            let slot = slots[slot_number];
180            let result = self.send_slot_data(client_address, 
181                                             slot, 
182                                             &mut requested_controller_data.packet_number);
183
184            if result.is_ok() {
185              already_sent.insert(slot_number as u8);
186            } else {
187              return false;
188            }
189          }
190        }
191      }
192
193      true
194    });
195
196    Ok(())
197  }
198
199  fn handle_request(&self, source: SocketAddr, request: Message) -> Result<()> {
200    match request.message_type {
201      MessageType::ProtocolVersion => {
202        self.send_protocol_version(source)
203      },
204      _ => {
205        match request.payload {
206          MessagePayload::ConnectedControllersRequest { amount, 
207                                                        slot_numbers } => {
208            for i in 0..amount {
209              let slot_number = slot_numbers[i as usize];
210              self.send_connected_controller_info(source, slot_number)?;
211            }
212    
213            Ok(())
214          },
215          MessagePayload::ControllerDataRequest(request) => { 
216            {
217              let mut connected_clients = self.connected_clients.lock().unwrap();
218              let requested = connected_clients.entry(source).or_insert(RequestedControllerData {
219                packet_number: 0,
220                slot_numbers: HashSet::new(),
221                mac_addresses: HashSet::new()
222              });
223              
224              match request {
225                ControllerDataRequest::ReportAll => {
226                  requested.slot_numbers.insert(0);
227                  requested.slot_numbers.insert(1);
228                  requested.slot_numbers.insert(2);
229                  requested.slot_numbers.insert(3);
230                },
231                ControllerDataRequest::SlotNumber(slot_number) => {
232                  requested.slot_numbers.insert(slot_number);
233                },
234                ControllerDataRequest::MAC(mac) => {
235                  requested.mac_addresses.insert(mac);
236                }
237              };
238            }
239    
240            self.send_controller_data()
241          },
242          _ => Ok(()) // ignore request
243        }
244      }
245    }
246  }
247}
248
249impl DsServer for Arc<Server> {
250  fn start(self, countinue_running: Arc<AtomicBool>) -> JoinHandle<()> {
251    let countinue_running = countinue_running.clone();
252
253    std::thread::spawn(move || {
254      let mut buf = [0 as u8; 100];
255      while countinue_running.load(Ordering::SeqCst) {
256        match self.socket.recv_from(&mut buf) {
257          Ok((amount, source)) => {
258            let message = parse_message(MessageSource::Client, &buf[..amount], true);
259            if let Ok(message) = message {
260              let _ = self.handle_request(source, message);
261            }
262          },
263          _ => ()
264        }
265      }
266    })
267  }
268
269  fn update_controller_info(&self, controller_info: ControllerInfo) {
270    assert!(controller_info.slot < 4);
271
272    let slot_number = controller_info.slot;
273    {
274      let mut slots = self.slots.lock().unwrap();
275      slots[slot_number as usize].controller_info = controller_info;
276    }
277
278    let connected_clients = self.connected_clients.lock().unwrap();
279    for &address in connected_clients.keys() {
280      let _ = self.send_connected_controller_info(address, slot_number);
281    }
282  }
283
284  fn update_controller_data(&self, slot_number: u8, controller_data: ControllerData) {
285    assert!(slot_number < 4);
286
287    {
288      let mut slots = self.slots.lock().unwrap();
289      slots[slot_number as usize].controller_data = controller_data;
290    }
291
292    let _ = self.send_controller_data();
293  }
294}