ola-rs 0.1.0

Rust client for Open Lighting Architecture RPC DMX control
Documentation
//! Rust client for Open Lighting Architecture RPC DMX control.
//!
//! This crate implements the small OLA RPC surface needed by live visual and
//! lighting tools: update DMX, stream DMX, read DMX, and blackout universes.
//! It is a clean Rust implementation informed by the public OLA RPC protocol.

use prost::Message;
use std::io::{Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;
use thiserror::Error;

/// OLA RPC framing version written into, and required in, every frame header.
const PROTOCOL_VERSION: u32 = 1;
/// Header bits holding the protocol version (the top four bits).
const VERSION_MASK: u32 = 0xf << 28;
/// Header bits holding the length of the serialized `RpcMessage` that follows (low 28 bits).
const SIZE_MASK: u32 = !VERSION_MASK;
/// Port used by `OlaConfig::default()` when connecting to the OLA daemon.
const DEFAULT_OLA_PORT: u16 = 9010;

/// Errors returned by [`OlaClient`] and the RPC framing helpers in this crate.
#[derive(Debug, Error)]
pub enum OlaError {
    /// Socket-level failure: name resolution, connect, read or write (including timeouts).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// An outgoing protobuf message could not be encoded.
    #[error("protobuf encode error: {0}")]
    Encode(#[from] prost::EncodeError),
    /// A received frame or RPC payload was not valid protobuf for the expected message.
    #[error("protobuf decode error: {0}")]
    Decode(#[from] prost::DecodeError),
    /// The version bits of a received frame header did not match the supported version.
    #[error("unsupported OLA RPC protocol version {0}")]
    UnsupportedProtocolVersion(u32),
    /// The daemon replied with a failure response; the payload is its error text.
    #[error("OLA RPC failed: {0}")]
    RpcFailed(String),
    /// The reply was neither a success nor a failure response (e.g. cancel or not implemented);
    /// the payload is the raw `RpcMessage` type value.
    #[error("unexpected OLA RPC response type {0}")]
    UnexpectedResponseType(i32),
    /// The reply's id did not match the id of the request that was sent.
    #[error("response id mismatch: expected {expected}, got {actual}")]
    ResponseIdMismatch { expected: u32, actual: u32 },
    /// A DMX payload longer than 512 bytes was passed in.
    #[error("DMX frame length {0} exceeds 512 bytes")]
    DmxFrameTooLong(usize),
}

/// Result type used throughout this crate.
///
/// The error type defaults to [`OlaError`], so `Result<T>` keeps working. Because `E` can be
/// overridden, `Result<T, E>` also still works in code that glob-imports this crate.
pub type Result<T, E = OlaError> = std::result::Result<T, E>;

/// Envelope written after every 4-byte frame header.
///
/// It carries the message type, the request id used to pair replies with requests, the RPC
/// method name (e.g. `"UpdateDmxData"`), and the serialized request or response payload.
#[derive(Clone, PartialEq, Message)]
struct RpcMessage {
    /// Raw [`RpcType`] value.
    #[prost(enumeration = "RpcType", required, tag = "1")]
    r#type: i32,
    /// Request id; replies are expected to echo the id of the request they answer.
    #[prost(uint32, optional, tag = "2")]
    id: Option<u32>,
    /// RPC method name; set on requests sent by this client.
    #[prost(string, optional, tag = "3")]
    name: Option<String>,
    /// Serialized inner message, or the error text for a failed response.
    #[prost(bytes, optional, tag = "4")]
    buffer: Option<Vec<u8>>,
}

/// Values of the `type` field of [`RpcMessage`].
///
/// This client sends `Request` (calls that expect a reply) and `StreamRequest`
/// (fire-and-forget calls). It accepts `Response` and `ResponseFailed` as replies; any other
/// reply type is reported as [`OlaError::UnexpectedResponseType`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, prost::Enumeration)]
#[repr(i32)]
enum RpcType {
    Request = 1,
    Response = 2,
    ResponseCancel = 3,
    ResponseFailed = 4,
    ResponseNotImplemented = 5,
    Disconnect = 6,
    DescriptorRequest = 7,
    DescriptorResponse = 8,
    RequestCancel = 9,
    StreamRequest = 10,
}

/// Reply payload of acknowledged RPCs such as `UpdateDmxData`.
///
/// NOTE(review): upstream OLA's `Ola.proto` appears to declare `Ack` as an empty message. If
/// so, the daemon never sends `success`, and it decodes as `false` even for a successful call.
/// Confirm this before relying on the flag; a `Response` reply itself already signals success.
#[derive(Clone, PartialEq, Message)]
pub struct Ack {
    /// Success flag carried in field 1 of the reply (see the note above).
    #[prost(bool, required, tag = "1")]
    pub success: bool,
}

/// DMX payload for a universe.
///
/// Used as the request for `UpdateDmxData` and `StreamDmxData`, and as the reply of `GetDmx`.
#[derive(Clone, PartialEq, Message)]
pub struct DmxData {
    /// Universe number the data belongs to.
    #[prost(int32, required, tag = "1")]
    pub universe: i32,
    /// Channel values, one byte per slot. This client refuses to send more than 512 bytes.
    #[prost(bytes, required, tag = "2")]
    pub data: Vec<u8>,
    /// Optional priority, forwarded as-is; `None` omits the field from the encoding.
    #[prost(int32, optional, tag = "3")]
    pub priority: Option<i32>,
}

/// Request payload naming a single universe; used by the `GetDmx` RPC.
#[derive(Clone, PartialEq, Message)]
pub struct UniverseRequest {
    /// Universe number to query.
    #[prost(int32, required, tag = "1")]
    pub universe: i32,
}

/// Connection settings for [`OlaClient::connect`].
#[derive(Debug, Clone)]
pub struct OlaConfig {
    /// Host name or IP address of the OLA daemon; resolved with `ToSocketAddrs`.
    pub host: String,
    /// TCP port of the daemon's RPC listener.
    pub port: u16,
    /// Upper bound for establishing the TCP connection.
    pub connect_timeout: Duration,
    /// Socket read timeout; `None` lets reads block indefinitely.
    pub read_timeout: Option<Duration>,
    /// Socket write timeout; `None` lets writes block indefinitely.
    pub write_timeout: Option<Duration>,
}

impl Default for OlaConfig {
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: DEFAULT_OLA_PORT,
            connect_timeout: Duration::from_secs(2),
            read_timeout: Some(Duration::from_secs(2)),
            write_timeout: Some(Duration::from_secs(2)),
        }
    }
}

/// Blocking client for the OLA daemon's RPC interface over a single TCP connection.
///
/// Every call takes `&mut self`, so requests on one client are issued one at a time.
pub struct OlaClient {
    /// Connected socket to the daemon.
    stream: TcpStream,
    /// Id of the most recently issued request; 0 before the first request.
    next_id: u32,
}

impl OlaClient {
    /// Opens a blocking TCP connection to the OLA daemon described by `config`.
    ///
    /// Every address that `config.host`/`config.port` resolves to is tried in order. Each
    /// attempt is bounded by `config.connect_timeout`, so a name that resolves to several
    /// addresses (e.g. both an IPv6 and an IPv4 loopback) still connects when only one of them
    /// is listening. On success `TCP_NODELAY` is enabled and the configured read and write
    /// timeouts are applied to the socket.
    ///
    /// # Errors
    ///
    /// Returns [`OlaError::Io`] in three cases:
    /// - name resolution fails;
    /// - no resolved address accepts the connection (the last connection error is reported,
    ///   or `NotFound` when nothing resolved);
    /// - a socket option cannot be set.
    pub fn connect(config: OlaConfig) -> Result<Self> {
        let mut last_error = None;
        for addr in (config.host.as_str(), config.port).to_socket_addrs()? {
            match TcpStream::connect_timeout(&addr, config.connect_timeout) {
                Ok(stream) => {
                    // Frames are small and latency-sensitive; don't let Nagle coalesce them.
                    stream.set_nodelay(true)?;
                    stream.set_read_timeout(config.read_timeout)?;
                    stream.set_write_timeout(config.write_timeout)?;
                    return Ok(Self { stream, next_id: 0 });
                }
                Err(err) => last_error = Some(err),
            }
        }
        Err(last_error
            .unwrap_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::NotFound, "no address resolved")
            })
            .into())
    }

    /// Connects using [`OlaConfig::default()`].
    ///
    /// # Errors
    ///
    /// Same as [`OlaClient::connect`].
    pub fn connect_default() -> Result<Self> {
        Self::connect(OlaConfig::default())
    }

    /// Sends `data` to `universe` with the `UpdateDmxData` RPC and waits for the reply.
    ///
    /// `priority` is forwarded as-is and omitted from the message when `None`.
    ///
    /// # Errors
    ///
    /// Returns [`OlaError::DmxFrameTooLong`] if `data` exceeds 512 bytes; nothing is sent in
    /// that case. Otherwise it returns any transport, decoding or RPC failure from the
    /// round trip.
    pub fn update_dmx(&mut self, universe: i32, data: &[u8], priority: Option<i32>) -> Result<Ack> {
        let request = Self::build_dmx_data(universe, data, priority)?;
        self.request("UpdateDmxData", &request)
    }

    /// Sends `data` to `universe` with the fire-and-forget `StreamDmxData` RPC.
    ///
    /// No reply is read, so failures on the daemon side are not reported to the caller.
    ///
    /// # Errors
    ///
    /// Returns [`OlaError::DmxFrameTooLong`] if `data` exceeds 512 bytes, or an I/O error if
    /// the frame cannot be written.
    pub fn stream_dmx(&mut self, universe: i32, data: &[u8], priority: Option<i32>) -> Result<()> {
        let request = Self::build_dmx_data(universe, data, priority)?;
        self.stream_request("StreamDmxData", &request)
    }

    /// Reads the current DMX data of `universe` with the `GetDmx` RPC.
    ///
    /// # Errors
    ///
    /// Returns any transport, decoding or RPC failure from the round trip.
    pub fn get_dmx(&mut self, universe: i32) -> Result<DmxData> {
        let request = UniverseRequest { universe };
        self.request("GetDmx", &request)
    }

    /// Sets all 512 channels of `universe` to zero via [`OlaClient::update_dmx`], without a
    /// priority.
    ///
    /// # Errors
    ///
    /// Same as [`OlaClient::update_dmx`].
    pub fn blackout(&mut self, universe: i32) -> Result<Ack> {
        self.update_dmx(universe, &[0; 512], None)
    }

    /// Streams 512 zero channels to `universe` via [`OlaClient::stream_dmx`], without a
    /// priority.
    ///
    /// # Errors
    ///
    /// Same as [`OlaClient::stream_dmx`].
    pub fn stream_blackout(&mut self, universe: i32) -> Result<()> {
        self.stream_dmx(universe, &[0; 512], None)
    }

    /// Validates `data` and wraps it in a [`DmxData`] message for `universe`.
    fn build_dmx_data(universe: i32, data: &[u8], priority: Option<i32>) -> Result<DmxData> {
        validate_dmx(data)?;
        Ok(DmxData {
            universe,
            data: data.to_vec(),
            priority,
        })
    }

    /// Sends the unary RPC `name` with `message` as payload and decodes the reply as `R`.
    ///
    /// A reply whose id belongs to a different request is discarded and the next frame is
    /// read. Such a reply is typically the late answer to an earlier call whose read timed
    /// out. Returning it as an error would leave every later call one reply behind. A reply
    /// without an id is still reported as [`OlaError::ResponseIdMismatch`].
    fn request<M, R>(&mut self, name: &str, message: &M) -> Result<R>
    where
        M: Message,
        R: Message + Default,
    {
        let id = self.next_request_id();
        let wrapper = RpcMessage {
            r#type: RpcType::Request as i32,
            id: Some(id),
            name: Some(name.to_string()),
            buffer: Some(message.encode_to_vec()),
        };
        self.write_wrapper(&wrapper)?;
        loop {
            let response = self.read_wrapper()?;
            let is_stale = matches!(response.id, Some(other) if other != id);
            if !is_stale {
                return self.decode_response(id, response);
            }
        }
    }

    /// Sends the RPC `name` as a `StreamRequest`, for which no reply is read.
    fn stream_request<M>(&mut self, name: &str, message: &M) -> Result<()>
    where
        M: Message,
    {
        let id = self.next_request_id();
        let wrapper = RpcMessage {
            r#type: RpcType::StreamRequest as i32,
            id: Some(id),
            name: Some(name.to_string()),
            buffer: Some(message.encode_to_vec()),
        };
        self.write_wrapper(&wrapper)
    }

    /// Returns the next request id.
    ///
    /// Ids cycle through `1..=i32::MAX`, so 0 is never issued. NOTE(review): the `i32::MAX`
    /// bound presumably keeps ids representable as signed 32-bit values on the daemon side;
    /// confirm this.
    fn next_request_id(&mut self) -> u32 {
        self.next_id = if self.next_id == i32::MAX as u32 { 1 } else { self.next_id + 1 };
        self.next_id
    }

    /// Writes `wrapper` as one frame: a 4-byte header followed by the encoded message.
    ///
    /// Header and body are assembled into a single buffer and written with one call. With
    /// `TCP_NODELAY` enabled, two separate writes would cost two syscalls and usually two
    /// segments per frame. The header uses native byte order. NOTE(review): this assumes
    /// the daemon shares the host's endianness, as on a local connection.
    fn write_wrapper(&mut self, wrapper: &RpcMessage) -> Result<()> {
        let body_len = wrapper.encoded_len();
        let header = build_header(body_len)?.to_ne_bytes();
        let mut frame = Vec::with_capacity(header.len() + body_len);
        frame.extend_from_slice(&header);
        wrapper.encode(&mut frame)?;
        self.stream.write_all(&frame)?;
        self.stream.flush()?;
        Ok(())
    }

    /// Reads one frame (native-endian header, then body) and decodes it as an [`RpcMessage`].
    ///
    /// Blocks until the whole frame arrives or the socket read timeout expires.
    fn read_wrapper(&mut self) -> Result<RpcMessage> {
        let mut header = [0u8; 4];
        self.stream.read_exact(&mut header)?;
        let len = parse_header(u32::from_ne_bytes(header))?;
        let mut body = vec![0u8; len];
        self.stream.read_exact(&mut body)?;
        Ok(RpcMessage::decode(body.as_slice())?)
    }

    /// Interprets `response` as the reply to request `expected_id`.
    ///
    /// A `Response` has its buffer decoded as `R`; a missing buffer decodes as empty. A
    /// `ResponseFailed` becomes [`OlaError::RpcFailed`] with the buffer as lossy UTF-8 text.
    /// Any other type becomes [`OlaError::UnexpectedResponseType`]. A missing or different
    /// id is reported as [`OlaError::ResponseIdMismatch`], with a missing id shown as 0.
    fn decode_response<R>(&self, expected_id: u32, response: RpcMessage) -> Result<R>
    where
        R: Message + Default,
    {
        let actual_id = response.id.unwrap_or_default();
        if actual_id != expected_id {
            return Err(OlaError::ResponseIdMismatch {
                expected: expected_id,
                actual: actual_id,
            });
        }

        match response.r#type {
            x if x == RpcType::Response as i32 => {
                let buffer = response.buffer.unwrap_or_default();
                Ok(R::decode(buffer.as_slice())?)
            }
            x if x == RpcType::ResponseFailed as i32 => {
                let buffer = response.buffer.unwrap_or_default();
                let message = String::from_utf8_lossy(&buffer).to_string();
                Err(OlaError::RpcFailed(message))
            }
            other => Err(OlaError::UnexpectedResponseType(other)),
        }
    }
}

/// Checks that `data` fits in a single 512-slot DMX universe.
///
/// # Errors
///
/// Returns [`OlaError::DmxFrameTooLong`] carrying the offending length when `data` has more
/// than 512 bytes. Empty and exactly-512-byte payloads are accepted.
fn validate_dmx(data: &[u8]) -> Result<()> {
    match data.len() {
        too_long if too_long > 512 => Err(OlaError::DmxFrameTooLong(too_long)),
        _ => Ok(()),
    }
}

/// Encodes the 4-byte OLA RPC frame header for a body of `length` bytes.
///
/// The top four bits carry `PROTOCOL_VERSION` and the low 28 bits (`SIZE_MASK`) carry the
/// body length.
///
/// # Errors
///
/// Returns an [`OlaError::Io`] of kind `InvalidInput` when `length` does not fit in the 28-bit
/// size field. Masking such a length would silently truncate it and corrupt the framing of
/// the stream.
fn build_header(length: usize) -> Result<u32> {
    let size = u32::try_from(length)
        .ok()
        .filter(|&size| size <= SIZE_MASK)
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("OLA RPC message of {length} bytes exceeds the 28-bit frame size field"),
            )
        })?;
    Ok(((PROTOCOL_VERSION << 28) & VERSION_MASK) | size)
}

/// Splits a received frame header into its version and body length, and returns the length.
///
/// # Errors
///
/// Returns [`OlaError::UnsupportedProtocolVersion`] with the received version when the top
/// four bits differ from `PROTOCOL_VERSION`.
fn parse_header(header: u32) -> Result<usize> {
    let body_len = header & SIZE_MASK;
    match (header & VERSION_MASK) >> 28 {
        version if version == PROTOCOL_VERSION => Ok(body_len as usize),
        version => Err(OlaError::UnsupportedProtocolVersion(version)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn header_roundtrip() {
        let header = build_header(1234).unwrap();
        assert_eq!(parse_header(header).unwrap(), 1234);
    }

    #[test]
    fn header_carries_protocol_version_in_top_bits() {
        let header = build_header(0).unwrap();
        assert_eq!(header >> 28, PROTOCOL_VERSION);
        assert_eq!(header & SIZE_MASK, 0);
    }

    #[test]
    fn rejects_unknown_protocol_version() {
        let header = (2u32 << 28) | 16;
        assert!(matches!(parse_header(header), Err(OlaError::UnsupportedProtocolVersion(2))));
    }

    #[test]
    fn accepts_empty_and_full_universe() {
        assert!(validate_dmx(&[]).is_ok());
        assert!(validate_dmx(&[0u8; 512]).is_ok());
    }

    #[test]
    fn rejects_oversized_dmx() {
        let data = vec![0u8; 513];
        assert!(matches!(validate_dmx(&data), Err(OlaError::DmxFrameTooLong(513))));
    }

    #[test]
    fn dmx_data_encodes() {
        let data = DmxData {
            universe: 1,
            data: vec![1, 2, 3],
            priority: Some(100),
        };
        let encoded = data.encode_to_vec();
        let decoded = DmxData::decode(encoded.as_slice()).unwrap();
        assert_eq!(decoded.universe, 1);
        assert_eq!(decoded.data, vec![1, 2, 3]);
        assert_eq!(decoded.priority, Some(100));
    }

    #[test]
    fn omitted_priority_roundtrips_as_none() {
        let data = DmxData {
            universe: 7,
            data: vec![255; 4],
            priority: None,
        };
        let decoded = DmxData::decode(data.encode_to_vec().as_slice()).unwrap();
        assert_eq!(decoded.universe, 7);
        assert_eq!(decoded.priority, None);
    }

    #[test]
    fn rpc_envelope_roundtrips() {
        let wrapper = RpcMessage {
            r#type: RpcType::StreamRequest as i32,
            id: Some(42),
            name: Some("StreamDmxData".to_string()),
            buffer: Some(vec![9, 8, 7]),
        };
        let decoded = RpcMessage::decode(wrapper.encode_to_vec().as_slice()).unwrap();
        assert_eq!(decoded, wrapper);
    }
}