use prost::Message;
use std::io::{Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;
use thiserror::Error;
/// Protocol version written into (and required in) the top nibble of every frame header.
const PROTOCOL_VERSION: u32 = 1;
/// Frame-header bits holding the protocol version: the top four bits.
const VERSION_MASK: u32 = 0xf << 28;
/// Frame-header bits holding the body length: everything below the version nibble.
const SIZE_MASK: u32 = !VERSION_MASK;
/// Port used by `OlaConfig::default()`.
const DEFAULT_OLA_PORT: u16 = 9010;
/// Errors returned by `OlaClient` operations and the framing helpers.
#[derive(Debug, Error)]
pub enum OlaError {
/// Any `std::io::Error`, e.g. from address resolution, connecting, setting socket
/// options, or socket reads/writes (including read/write timeouts).
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// A protobuf message could not be encoded.
#[error("protobuf encode error: {0}")]
Encode(#[from] prost::EncodeError),
/// Received bytes did not decode as the expected protobuf message.
#[error("protobuf decode error: {0}")]
Decode(#[from] prost::DecodeError),
/// A frame header carried a version other than `PROTOCOL_VERSION`; holds that version.
#[error("unsupported OLA RPC protocol version {0}")]
UnsupportedProtocolVersion(u32),
/// The server replied with `RpcType::ResponseFailed`; holds the reply buffer decoded
/// lossily as UTF-8.
#[error("OLA RPC failed: {0}")]
RpcFailed(String),
/// A reply's type was neither `Response` nor `ResponseFailed`; holds the raw type value.
#[error("unexpected OLA RPC response type {0}")]
UnexpectedResponseType(i32),
/// A reply's id did not match the id of the request it was decoded for.
#[error("response id mismatch: expected {expected}, got {actual}")]
ResponseIdMismatch { expected: u32, actual: u32 },
/// Raised when a DMX payload longer than 512 bytes is supplied; holds its length.
#[error("DMX frame length {0} exceeds 512 bytes")]
DmxFrameTooLong(usize),
}
/// Result type used throughout this module, with `OlaError` as the error.
pub type Result<T> = std::result::Result<T, OlaError>;
/// Envelope that forms the body of every RPC frame sent or received by `OlaClient`.
#[derive(Clone, PartialEq, Message)]
struct RpcMessage {
/// Raw `RpcType` discriminant (request, response, failure, ...).
#[prost(enumeration = "RpcType", required, tag = "1")]
r#type: i32,
/// Sequence id used to pair a reply with the request that produced it.
#[prost(uint32, optional, tag = "2")]
id: Option<u32>,
/// RPC method name, e.g. "UpdateDmxData" or "GetDmx".
#[prost(string, optional, tag = "3")]
name: Option<String>,
/// Encoded inner request/response message; for failed replies, the failure text.
#[prost(bytes, optional, tag = "4")]
buffer: Option<Vec<u8>>,
}
/// Message kinds carried in `RpcMessage::r#type`; the discriminants are the wire values.
///
/// This client sends `Request` and `StreamRequest` and interprets `Response` and
/// `ResponseFailed` replies; the remaining variants are not otherwise referenced here.
#[derive(Clone, Copy, Debug, PartialEq, Eq, prost::Enumeration)]
#[repr(i32)]
enum RpcType {
Request = 1,
Response = 2,
ResponseCancel = 3,
ResponseFailed = 4,
ResponseNotImplemented = 5,
Disconnect = 6,
DescriptorRequest = 7,
DescriptorResponse = 8,
RequestCancel = 9,
StreamRequest = 10,
}
/// Reply message decoded for `UpdateDmxData` calls (`OlaClient::update_dmx`, `blackout`).
// NOTE(review): confirm field 1 matches the server's `Ack` definition. prost does not appear
// to reject messages missing a `required` field, so if the server's Ack carries no field 1,
// `success` would decode as `false` even for successful calls.
#[derive(Clone, PartialEq, Message)]
pub struct Ack {
/// Success flag as decoded from the reply.
#[prost(bool, required, tag = "1")]
pub success: bool,
}
/// DMX payload for one universe: the request body for `UpdateDmxData`/`StreamDmxData`
/// and the reply body for `GetDmx`.
#[derive(Clone, PartialEq, Message)]
pub struct DmxData {
/// Universe number the data belongs to.
#[prost(int32, required, tag = "1")]
pub universe: i32,
/// Channel values; `OlaClient` rejects payloads longer than 512 bytes before sending.
#[prost(bytes, required, tag = "2")]
pub data: Vec<u8>,
/// Optional priority sent with the data; `None` leaves the field out of the encoding.
#[prost(int32, optional, tag = "3")]
pub priority: Option<i32>,
}
/// Request body naming a single universe; sent by `OlaClient::get_dmx`.
#[derive(Clone, PartialEq, Message)]
pub struct UniverseRequest {
/// Universe to query.
#[prost(int32, required, tag = "1")]
pub universe: i32,
}
/// Connection settings consumed by `OlaClient::connect`.
#[derive(Debug, Clone)]
pub struct OlaConfig {
    /// Host name or address that is resolved and connected to.
    pub host: String,
    /// TCP port to connect to.
    pub port: u16,
    /// Time limit for a TCP connection attempt.
    pub connect_timeout: Duration,
    /// Socket read timeout; `None` lets reads block indefinitely.
    pub read_timeout: Option<Duration>,
    /// Socket write timeout; `None` lets writes block indefinitely.
    pub write_timeout: Option<Duration>,
}

impl Default for OlaConfig {
    /// `127.0.0.1` on `DEFAULT_OLA_PORT`, with a two-second connect, read and write timeout.
    fn default() -> Self {
        let two_seconds = Duration::from_secs(2);
        OlaConfig {
            host: String::from("127.0.0.1"),
            port: DEFAULT_OLA_PORT,
            connect_timeout: two_seconds,
            read_timeout: Some(two_seconds),
            write_timeout: Some(two_seconds),
        }
    }
}
/// Blocking OLA RPC client over a single TCP connection.
///
/// Each call writes its request and, for non-streaming calls, reads the reply before
/// returning; methods take `&mut self`, so calls on one client never overlap.
pub struct OlaClient {
/// Connected socket, configured by `connect`.
stream: TcpStream,
/// Last request id handed out; starts at 0 so the first id used is 1.
next_id: u32,
}
impl OlaClient {
pub fn connect(config: OlaConfig) -> Result<Self> {
let addr = (config.host.as_str(), config.port)
.to_socket_addrs()?
.next()
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "no address resolved"))?;
let stream = TcpStream::connect_timeout(&addr, config.connect_timeout)?;
stream.set_nodelay(true)?;
stream.set_read_timeout(config.read_timeout)?;
stream.set_write_timeout(config.write_timeout)?;
Ok(Self { stream, next_id: 0 })
}
pub fn connect_default() -> Result<Self> {
Self::connect(OlaConfig::default())
}
pub fn update_dmx(&mut self, universe: i32, data: &[u8], priority: Option<i32>) -> Result<Ack> {
validate_dmx(data)?;
let request = DmxData {
universe,
data: data.to_vec(),
priority,
};
self.request("UpdateDmxData", &request)
}
pub fn stream_dmx(&mut self, universe: i32, data: &[u8], priority: Option<i32>) -> Result<()> {
validate_dmx(data)?;
let request = DmxData {
universe,
data: data.to_vec(),
priority,
};
self.stream_request("StreamDmxData", &request)
}
pub fn get_dmx(&mut self, universe: i32) -> Result<DmxData> {
let request = UniverseRequest { universe };
self.request("GetDmx", &request)
}
pub fn blackout(&mut self, universe: i32) -> Result<Ack> {
self.update_dmx(universe, &[0; 512], None)
}
pub fn stream_blackout(&mut self, universe: i32) -> Result<()> {
self.stream_dmx(universe, &[0; 512], None)
}
fn request<M, R>(&mut self, name: &str, message: &M) -> Result<R>
where
M: Message,
R: Message + Default,
{
let id = self.next_request_id();
let wrapper = RpcMessage {
r#type: RpcType::Request as i32,
id: Some(id),
name: Some(name.to_string()),
buffer: Some(message.encode_to_vec()),
};
self.write_wrapper(&wrapper)?;
let response = self.read_wrapper()?;
self.decode_response(id, response)
}
fn stream_request<M>(&mut self, name: &str, message: &M) -> Result<()>
where
M: Message,
{
let id = self.next_request_id();
let wrapper = RpcMessage {
r#type: RpcType::StreamRequest as i32,
id: Some(id),
name: Some(name.to_string()),
buffer: Some(message.encode_to_vec()),
};
self.write_wrapper(&wrapper)
}
fn next_request_id(&mut self) -> u32 {
self.next_id = if self.next_id == i32::MAX as u32 { 1 } else { self.next_id + 1 };
self.next_id
}
fn write_wrapper(&mut self, wrapper: &RpcMessage) -> Result<()> {
let body = wrapper.encode_to_vec();
let header = build_header(body.len())?.to_ne_bytes();
self.stream.write_all(&header)?;
self.stream.write_all(&body)?;
self.stream.flush()?;
Ok(())
}
fn read_wrapper(&mut self) -> Result<RpcMessage> {
let mut header = [0u8; 4];
self.stream.read_exact(&mut header)?;
let len = parse_header(u32::from_ne_bytes(header))?;
let mut body = vec![0u8; len];
self.stream.read_exact(&mut body)?;
Ok(RpcMessage::decode(body.as_slice())?)
}
fn decode_response<R>(&self, expected_id: u32, response: RpcMessage) -> Result<R>
where
R: Message + Default,
{
let actual_id = response.id.unwrap_or_default();
if actual_id != expected_id {
return Err(OlaError::ResponseIdMismatch {
expected: expected_id,
actual: actual_id,
});
}
match response.r#type {
x if x == RpcType::Response as i32 => {
let buffer = response.buffer.unwrap_or_default();
Ok(R::decode(buffer.as_slice())?)
}
x if x == RpcType::ResponseFailed as i32 => {
let buffer = response.buffer.unwrap_or_default();
let message = String::from_utf8_lossy(&buffer).to_string();
Err(OlaError::RpcFailed(message))
}
other => Err(OlaError::UnexpectedResponseType(other)),
}
}
}
/// Rejects DMX payloads longer than 512 bytes.
///
/// # Errors
///
/// `OlaError::DmxFrameTooLong` carrying the offending length.
fn validate_dmx(data: &[u8]) -> Result<()> {
    const MAX_DMX_LEN: usize = 512;
    match data.len() {
        len if len > MAX_DMX_LEN => Err(OlaError::DmxFrameTooLong(len)),
        _ => Ok(()),
    }
}
fn build_header(length: usize) -> Result<u32> {
let length = u32::try_from(length).map_err(|_| OlaError::DmxFrameTooLong(length))?;
Ok(((PROTOCOL_VERSION << 28) & VERSION_MASK) | (length & SIZE_MASK))
}
/// Splits a frame header into its version and body length, returning the length.
///
/// # Errors
///
/// `OlaError::UnsupportedProtocolVersion` when the version bits are not `PROTOCOL_VERSION`.
fn parse_header(header: u32) -> Result<usize> {
    match (header & VERSION_MASK) >> 28 {
        PROTOCOL_VERSION => Ok((header & SIZE_MASK) as usize),
        other => Err(OlaError::UnsupportedProtocolVersion(other)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn header_roundtrip() {
        let header = build_header(1234).unwrap();
        assert_eq!(parse_header(header).unwrap(), 1234);
    }

    #[test]
    fn header_carries_protocol_version_in_top_nibble() {
        let header = build_header(0).unwrap();
        assert_eq!(header >> 28, PROTOCOL_VERSION);
        assert_eq!(header & SIZE_MASK, 0);
    }

    #[test]
    fn rejects_unknown_protocol_version() {
        let header = (2u32 << 28) | 16;
        assert!(matches!(
            parse_header(header),
            Err(OlaError::UnsupportedProtocolVersion(2))
        ));
    }

    #[test]
    fn accepts_dmx_frames_up_to_512_bytes() {
        assert!(validate_dmx(&[]).is_ok());
        assert!(validate_dmx(&[0u8; 512]).is_ok());
    }

    #[test]
    fn rejects_oversized_dmx() {
        let data = vec![0u8; 513];
        assert!(matches!(validate_dmx(&data), Err(OlaError::DmxFrameTooLong(513))));
    }

    #[test]
    fn dmx_data_encodes() {
        let data = DmxData {
            universe: 1,
            data: vec![1, 2, 3],
            priority: Some(100),
        };
        let encoded = data.encode_to_vec();
        let decoded = DmxData::decode(encoded.as_slice()).unwrap();
        assert_eq!(decoded.universe, 1);
        assert_eq!(decoded.data, vec![1, 2, 3]);
        assert_eq!(decoded.priority, Some(100));
    }

    #[test]
    fn rpc_message_roundtrip() {
        let message = RpcMessage {
            r#type: RpcType::Response as i32,
            id: Some(7),
            name: Some("GetDmx".to_string()),
            buffer: Some(vec![1, 2, 3]),
        };
        let decoded = RpcMessage::decode(message.encode_to_vec().as_slice()).unwrap();
        assert_eq!(decoded, message);
    }

    /// Drives a full request/reply exchange against a loopback fake server that
    /// answers with `ResponseFailed`.
    #[test]
    fn failed_rpc_surfaces_server_message() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let server = std::thread::spawn(move || {
            let (mut socket, _) = listener.accept().unwrap();
            let mut header = [0u8; 4];
            socket.read_exact(&mut header).unwrap();
            let mut body = vec![0u8; parse_header(u32::from_ne_bytes(header)).unwrap()];
            socket.read_exact(&mut body).unwrap();
            let request = RpcMessage::decode(body.as_slice()).unwrap();
            let reply = RpcMessage {
                r#type: RpcType::ResponseFailed as i32,
                id: request.id,
                name: None,
                buffer: Some(b"boom".to_vec()),
            }
            .encode_to_vec();
            socket
                .write_all(&build_header(reply.len()).unwrap().to_ne_bytes())
                .unwrap();
            socket.write_all(&reply).unwrap();
        });
        let mut client = OlaClient::connect(OlaConfig {
            port,
            ..OlaConfig::default()
        })
        .unwrap();
        let result = client.get_dmx(1);
        server.join().unwrap();
        assert!(matches!(result, Err(OlaError::RpcFailed(ref m)) if m == "boom"));
    }
}