use crate::interpolation::*;
use crate::track::*;
use crate::Tracks;
use byteorder::ByteOrder;
use byteorder::{BigEndian, ReadBytesExt};
use std::hint::unreachable_unchecked;
use std::{
convert::TryFrom,
io::{self, Cursor, Read, Write},
net::{TcpStream, ToSocketAddrs},
};
use thiserror::Error;
/// Greeting the client writes to the tracker to open the handshake.
const CLIENT_GREETING: &[u8] = b"hello, synctracker!";
/// Exact reply the tracker must send back for the handshake to succeed.
const SERVER_GREETING: &[u8] = b"hello, demo!";

// Command ids: the first byte of every command on the wire.
const SET_KEY: u8 = 0;
const DELETE_KEY: u8 = 1;
const GET_TRACK: u8 = 2;
const SET_ROW: u8 = 3;
const PAUSE: u8 = 4;
const SAVE_TRACKS: u8 = 5;

// Payload sizes in bytes, not counting the one-byte command id.

/// Track index (u32), row (u32), value (f32), interpolation (u8).
const SET_KEY_LEN: usize = std::mem::size_of::<u32>()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<f32>()
    + std::mem::size_of::<u8>();
/// Track index (u32), row (u32).
const DELETE_KEY_LEN: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u32>();
/// Big-endian length prefix (u32) of the track name that follows.
const GET_TRACK_LEN: usize = std::mem::size_of::<u32>();
/// Row (u32).
const SET_ROW_LEN: usize = std::mem::size_of::<u32>();
/// Pause flag (u8).
const PAUSE_LEN: usize = std::mem::size_of::<u8>();
/// Largest payload of any command read from the tracker; sizes the receive buffer.
const MAX_COMMAND_LEN: usize = SET_KEY_LEN;
#[derive(Debug, Error)]
pub enum Error {
#[error("Failed to establish a TCP connection with the Rocket tracker")]
Connect(#[source] std::io::Error),
#[error("Handshake with the Rocket tracker failed")]
Handshake(#[source] std::io::Error),
#[error("The Rocket tracker greeting {0:?} wasn't correct")]
HandshakeGreetingMismatch([u8; SERVER_GREETING.len()]),
#[error("Cannot set Rocket's TCP connection to nonblocking mode")]
SetNonblocking(#[source] std::io::Error),
#[error("Rocket tracker disconnected")]
IOError(#[source] std::io::Error),
}
/// Where the receive side is while decoding the current incoming command.
#[derive(Debug)]
enum ClientState {
    /// Waiting for the one-byte command id of the next command.
    New,
    /// The command id has been read; this many payload bytes are still missing.
    Incomplete(usize),
    /// The whole command is buffered and ready to be decoded.
    Complete,
}
/// A command from the tracker that the host application has to act on.
///
/// Track edits (`SET_KEY` / `DELETE_KEY`) are applied to the client's tracks
/// internally and are not reported as events.
///
/// Equality and hashing are derived so callers can compare or deduplicate events.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Event {
    /// The tracker sent a `SET_ROW` command carrying this row
    /// (presumably the row playback should jump to).
    SetRow(u32),
    /// The tracker sent a `PAUSE` command; `true` when its flag byte was `1`.
    Pause(bool),
    /// The tracker sent a `SAVE_TRACKS` command.
    SaveTracks,
}
/// Outcome of one step of the receive state machine.
#[derive(Debug)]
enum ReceiveResult {
    /// A complete command produced an event for the caller.
    Some(Event),
    /// Nothing to report now: either no data is ready on the socket, or the
    /// command just processed produced no event.
    None,
    /// Progress was made but the current command is not finished; step again.
    Incomplete,
}
/// Client side of the Rocket sync-tracker protocol over a TCP connection.
#[derive(Debug)]
pub struct RocketClient {
    /// Connection to the tracker.
    stream: TcpStream,
    /// Decoder position within the command currently being received.
    state: ClientState,
    /// Bytes received so far for the current command, starting with its id byte.
    cmd: Vec<u8>,
    /// Tracks requested so far, in request order. Incoming track indices from
    /// `SET_KEY` / `DELETE_KEY` are resolved against positions in this list.
    tracks: Vec<Track>,
}
impl RocketClient {
pub fn new() -> Result<Self, Error> {
Self::connect(("localhost", 1338))
}
pub fn connect(addr: impl ToSocketAddrs) -> Result<Self, Error> {
let stream = TcpStream::connect(addr).map_err(Error::Connect)?;
let mut rocket = Self {
stream,
state: ClientState::New,
cmd: Vec::new(),
tracks: Vec::new(),
};
rocket.handshake()?;
rocket
.stream
.set_nonblocking(true)
.map_err(Error::SetNonblocking)?;
Ok(rocket)
}
pub fn get_track_mut(&mut self, name: &str) -> Result<&mut Track, Error> {
if let Some((i, _)) = self
.tracks
.iter()
.enumerate()
.find(|(_, t)| t.get_name() == name)
{
Ok(&mut self.tracks[i])
} else {
let mut buf = [GET_TRACK; 1 + GET_TRACK_LEN];
let name_len = u32::try_from(name.len()).expect("Track name too long");
BigEndian::write_u32(&mut buf[1..][..GET_TRACK_LEN], name_len);
self.stream.write_all(&buf).map_err(Error::IOError)?;
self.stream
.write_all(name.as_bytes())
.map_err(Error::IOError)?;
self.tracks.push(Track::new(name));
let track = self
.tracks
.last_mut()
.unwrap_or_else(|| unsafe { unreachable_unchecked() });
Ok(track)
}
}
pub fn get_track(&self, name: &str) -> Option<&Track> {
self.tracks.iter().find(|t| t.get_name() == name)
}
pub fn save_tracks(&self) -> &Tracks {
&self.tracks
}
pub fn set_row(&mut self, row: u32) -> Result<(), Error> {
let mut buf = [SET_ROW; 1 + SET_ROW_LEN];
BigEndian::write_u32(&mut buf[1..][..SET_ROW_LEN], row);
self.stream.write_all(&buf).map_err(Error::IOError)
}
pub fn poll_events(&mut self) -> Result<Option<Event>, Error> {
loop {
match self.poll_event()? {
ReceiveResult::None => return Ok(None),
ReceiveResult::Incomplete => { }
ReceiveResult::Some(event) => return Ok(Some(event)),
}
}
}
fn poll_event(&mut self) -> Result<ReceiveResult, Error> {
match self.state {
ClientState::New => self.poll_event_new(),
ClientState::Incomplete(bytes) => self.poll_event_incomplete(bytes),
ClientState::Complete => Ok(self.process_event().unwrap_or_else(|_| unreachable!())),
}
}
fn poll_event_new(&mut self) -> Result<ReceiveResult, Error> {
let mut buf = [0; 1];
match self.stream.read_exact(&mut buf) {
Ok(()) => {
self.cmd.extend_from_slice(&buf);
match self.cmd[0] {
SET_KEY => self.state = ClientState::Incomplete(SET_KEY_LEN),
DELETE_KEY => self.state = ClientState::Incomplete(DELETE_KEY_LEN),
SET_ROW => self.state = ClientState::Incomplete(SET_ROW_LEN),
PAUSE => self.state = ClientState::Incomplete(PAUSE_LEN),
SAVE_TRACKS => self.state = ClientState::Complete,
_ => self.state = ClientState::Complete, }
Ok(ReceiveResult::Incomplete)
}
Err(e) => match e.kind() {
std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
_ => Err(Error::IOError(e)),
},
}
}
fn poll_event_incomplete(&mut self, bytes: usize) -> Result<ReceiveResult, Error> {
let mut buf = [0; MAX_COMMAND_LEN];
match self.stream.read(&mut buf[..bytes]) {
Ok(bytes_read) => {
self.cmd.extend_from_slice(&buf[..bytes_read]);
if bytes - bytes_read > 0 {
self.state = ClientState::Incomplete(bytes - bytes_read);
} else {
self.state = ClientState::Complete;
}
Ok(ReceiveResult::Incomplete)
}
Err(e) => match e.kind() {
std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
_ => Err(Error::IOError(e)),
},
}
}
fn process_event(&mut self) -> Result<ReceiveResult, io::Error> {
let mut result = ReceiveResult::None;
let mut cursor = Cursor::new(&self.cmd);
let cmd = cursor.read_u8()?;
match cmd {
SET_KEY => {
let index = usize::try_from(cursor.read_u32::<BigEndian>()?).unwrap();
let track = &mut self.tracks[index];
let row = cursor.read_u32::<BigEndian>()?;
let value = cursor.read_f32::<BigEndian>()?;
let interpolation = Interpolation::from(cursor.read_u8()?);
let key = Key::new(row, value, interpolation);
track.set_key(key);
}
DELETE_KEY => {
let index = usize::try_from(cursor.read_u32::<BigEndian>()?).unwrap();
let track = &mut self.tracks[index];
let row = cursor.read_u32::<BigEndian>()?;
track.delete_key(row);
}
SET_ROW => {
let row = cursor.read_u32::<BigEndian>()?;
result = ReceiveResult::Some(Event::SetRow(row));
}
PAUSE => {
let flag = cursor.read_u8()? == 1;
result = ReceiveResult::Some(Event::Pause(flag));
}
SAVE_TRACKS => {
result = ReceiveResult::Some(Event::SaveTracks);
}
_ => eprintln!("rocket: Unknown command: {:?}", cmd),
}
self.cmd.clear();
self.state = ClientState::New;
Ok(result)
}
fn handshake(&mut self) -> Result<(), Error> {
self.stream
.write_all(CLIENT_GREETING)
.map_err(Error::Handshake)?;
let mut buf = [0; SERVER_GREETING.len()];
self.stream.read_exact(&mut buf).map_err(Error::Handshake)?;
if buf == SERVER_GREETING {
Ok(())
} else {
Err(Error::HandshakeGreetingMismatch(buf))
}
}
}