// ola_rs/lib.rs

//! Rust client for Open Lighting Architecture (OLA) RPC DMX control.
//!
//! This crate implements the small subset of the OLA RPC surface needed by live
//! visual and lighting tools: updating DMX, streaming DMX, reading DMX, and
//! blacking out universes. It is a clean Rust implementation informed by the
//! public OLA RPC protocol.

use prost::Message;
use std::io::{ErrorKind, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;
use thiserror::Error;
12
/// OLA RPC protocol version written into, and required in, the top four bits
/// of every frame header.
const PROTOCOL_VERSION: u32 = 1;
/// Frame-header bits that carry the protocol version.
const VERSION_MASK: u32 = 0xF000_0000;
/// Frame-header bits that carry the body length in bytes: everything not
/// covered by [`VERSION_MASK`].
const SIZE_MASK: u32 = !VERSION_MASK;
/// TCP port used by [`OlaConfig::default`] to reach the `olad` RPC listener.
const DEFAULT_OLA_PORT: u16 = 9010;
17
18#[derive(Debug, Error)]
19pub enum OlaError {
20    #[error("I/O error: {0}")]
21    Io(#[from] std::io::Error),
22    #[error("protobuf encode error: {0}")]
23    Encode(#[from] prost::EncodeError),
24    #[error("protobuf decode error: {0}")]
25    Decode(#[from] prost::DecodeError),
26    #[error("unsupported OLA RPC protocol version {0}")]
27    UnsupportedProtocolVersion(u32),
28    #[error("OLA RPC failed: {0}")]
29    RpcFailed(String),
30    #[error("unexpected OLA RPC response type {0}")]
31    UnexpectedResponseType(i32),
32    #[error("response id mismatch: expected {expected}, got {actual}")]
33    ResponseIdMismatch { expected: u32, actual: u32 },
34    #[error("DMX frame length {0} exceeds 512 bytes")]
35    DmxFrameTooLong(usize),
36}
37
38pub type Result<T> = std::result::Result<T, OlaError>;
39
40#[derive(Clone, PartialEq, Message)]
41struct RpcMessage {
42    #[prost(enumeration = "RpcType", required, tag = "1")]
43    r#type: i32,
44    #[prost(uint32, optional, tag = "2")]
45    id: Option<u32>,
46    #[prost(string, optional, tag = "3")]
47    name: Option<String>,
48    #[prost(bytes, optional, tag = "4")]
49    buffer: Option<Vec<u8>>,
50}
51
52#[derive(Clone, Copy, Debug, PartialEq, Eq, prost::Enumeration)]
53#[repr(i32)]
54enum RpcType {
55    Request = 1,
56    Response = 2,
57    ResponseCancel = 3,
58    ResponseFailed = 4,
59    ResponseNotImplemented = 5,
60    Disconnect = 6,
61    DescriptorRequest = 7,
62    DescriptorResponse = 8,
63    RequestCancel = 9,
64    StreamRequest = 10,
65}
66
67#[derive(Clone, PartialEq, Message)]
68pub struct Ack {
69    #[prost(bool, required, tag = "1")]
70    pub success: bool,
71}
72
73#[derive(Clone, PartialEq, Message)]
74pub struct DmxData {
75    #[prost(int32, required, tag = "1")]
76    pub universe: i32,
77    #[prost(bytes, required, tag = "2")]
78    pub data: Vec<u8>,
79    #[prost(int32, optional, tag = "3")]
80    pub priority: Option<i32>,
81}
82
83#[derive(Clone, PartialEq, Message)]
84pub struct UniverseRequest {
85    #[prost(int32, required, tag = "1")]
86    pub universe: i32,
87}
88
/// Connection settings for [`OlaClient::connect`].
#[derive(Clone, Debug)]
pub struct OlaConfig {
    /// Host name or IP address of the `olad` daemon, resolved with
    /// `ToSocketAddrs` at connect time.
    pub host: String,
    /// TCP port of the `olad` RPC listener.
    pub port: u16,
    /// Upper bound on establishing the TCP connection.
    pub connect_timeout: Duration,
    /// Socket read timeout; `None` lets reads block indefinitely.
    pub read_timeout: Option<Duration>,
    /// Socket write timeout; `None` lets writes block indefinitely.
    pub write_timeout: Option<Duration>,
}
97
98impl Default for OlaConfig {
99    fn default() -> Self {
100        Self {
101            host: "127.0.0.1".to_string(),
102            port: DEFAULT_OLA_PORT,
103            connect_timeout: Duration::from_secs(2),
104            read_timeout: Some(Duration::from_secs(2)),
105            write_timeout: Some(Duration::from_secs(2)),
106        }
107    }
108}
109
/// Blocking client for one TCP connection to `olad`'s RPC port.
///
/// Each request is written and its reply read before the call returns. Create
/// a client with [`OlaClient::connect`] or [`OlaClient::connect_default`].
#[derive(Debug)]
pub struct OlaClient {
    /// Connected socket carrying length-prefixed RPC frames.
    stream: TcpStream,
    /// Id of the most recently issued request; `0` before the first request.
    next_id: u32,
}
114
115impl OlaClient {
116    pub fn connect(config: OlaConfig) -> Result<Self> {
117        let addr = (config.host.as_str(), config.port)
118            .to_socket_addrs()?
119            .next()
120            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "no address resolved"))?;
121        let stream = TcpStream::connect_timeout(&addr, config.connect_timeout)?;
122        stream.set_nodelay(true)?;
123        stream.set_read_timeout(config.read_timeout)?;
124        stream.set_write_timeout(config.write_timeout)?;
125        Ok(Self { stream, next_id: 0 })
126    }
127
128    pub fn connect_default() -> Result<Self> {
129        Self::connect(OlaConfig::default())
130    }
131
132    pub fn update_dmx(&mut self, universe: i32, data: &[u8], priority: Option<i32>) -> Result<Ack> {
133        validate_dmx(data)?;
134        let request = DmxData {
135            universe,
136            data: data.to_vec(),
137            priority,
138        };
139        self.request("UpdateDmxData", &request)
140    }
141
142    pub fn stream_dmx(&mut self, universe: i32, data: &[u8], priority: Option<i32>) -> Result<()> {
143        validate_dmx(data)?;
144        let request = DmxData {
145            universe,
146            data: data.to_vec(),
147            priority,
148        };
149        self.stream_request("StreamDmxData", &request)
150    }
151
152    pub fn get_dmx(&mut self, universe: i32) -> Result<DmxData> {
153        let request = UniverseRequest { universe };
154        self.request("GetDmx", &request)
155    }
156
157    pub fn blackout(&mut self, universe: i32) -> Result<Ack> {
158        self.update_dmx(universe, &[0; 512], None)
159    }
160
161    pub fn stream_blackout(&mut self, universe: i32) -> Result<()> {
162        self.stream_dmx(universe, &[0; 512], None)
163    }
164
165    fn request<M, R>(&mut self, name: &str, message: &M) -> Result<R>
166    where
167        M: Message,
168        R: Message + Default,
169    {
170        let id = self.next_request_id();
171        let wrapper = RpcMessage {
172            r#type: RpcType::Request as i32,
173            id: Some(id),
174            name: Some(name.to_string()),
175            buffer: Some(message.encode_to_vec()),
176        };
177        self.write_wrapper(&wrapper)?;
178        let response = self.read_wrapper()?;
179        self.decode_response(id, response)
180    }
181
182    fn stream_request<M>(&mut self, name: &str, message: &M) -> Result<()>
183    where
184        M: Message,
185    {
186        let id = self.next_request_id();
187        let wrapper = RpcMessage {
188            r#type: RpcType::StreamRequest as i32,
189            id: Some(id),
190            name: Some(name.to_string()),
191            buffer: Some(message.encode_to_vec()),
192        };
193        self.write_wrapper(&wrapper)
194    }
195
196    fn next_request_id(&mut self) -> u32 {
197        self.next_id = if self.next_id == i32::MAX as u32 { 1 } else { self.next_id + 1 };
198        self.next_id
199    }
200
201    fn write_wrapper(&mut self, wrapper: &RpcMessage) -> Result<()> {
202        let body = wrapper.encode_to_vec();
203        let header = build_header(body.len())?.to_ne_bytes();
204        self.stream.write_all(&header)?;
205        self.stream.write_all(&body)?;
206        self.stream.flush()?;
207        Ok(())
208    }
209
210    fn read_wrapper(&mut self) -> Result<RpcMessage> {
211        let mut header = [0u8; 4];
212        self.stream.read_exact(&mut header)?;
213        let len = parse_header(u32::from_ne_bytes(header))?;
214        let mut body = vec![0u8; len];
215        self.stream.read_exact(&mut body)?;
216        Ok(RpcMessage::decode(body.as_slice())?)
217    }
218
219    fn decode_response<R>(&self, expected_id: u32, response: RpcMessage) -> Result<R>
220    where
221        R: Message + Default,
222    {
223        let actual_id = response.id.unwrap_or_default();
224        if actual_id != expected_id {
225            return Err(OlaError::ResponseIdMismatch {
226                expected: expected_id,
227                actual: actual_id,
228            });
229        }
230
231        match response.r#type {
232            x if x == RpcType::Response as i32 => {
233                let buffer = response.buffer.unwrap_or_default();
234                Ok(R::decode(buffer.as_slice())?)
235            }
236            x if x == RpcType::ResponseFailed as i32 => {
237                let buffer = response.buffer.unwrap_or_default();
238                let message = String::from_utf8_lossy(&buffer).to_string();
239                Err(OlaError::RpcFailed(message))
240            }
241            other => Err(OlaError::UnexpectedResponseType(other)),
242        }
243    }
244}
245
246fn validate_dmx(data: &[u8]) -> Result<()> {
247    if data.len() > 512 {
248        return Err(OlaError::DmxFrameTooLong(data.len()));
249    }
250    Ok(())
251}
252
253fn build_header(length: usize) -> Result<u32> {
254    let length = u32::try_from(length).map_err(|_| OlaError::DmxFrameTooLong(length))?;
255    Ok(((PROTOCOL_VERSION << 28) & VERSION_MASK) | (length & SIZE_MASK))
256}
257
258fn parse_header(header: u32) -> Result<usize> {
259    let version = (header & VERSION_MASK) >> 28;
260    if version != PROTOCOL_VERSION {
261        return Err(OlaError::UnsupportedProtocolVersion(version));
262    }
263    Ok((header & SIZE_MASK) as usize)
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Headers round-trip across the whole representable length range.
    #[test]
    fn header_roundtrip() {
        for len in [0usize, 1, 1234, SIZE_MASK as usize] {
            let header = build_header(len).unwrap();
            assert_eq!(parse_header(header).unwrap(), len);
        }
    }

    /// A header stamped with a foreign protocol version is refused.
    #[test]
    fn header_rejects_other_protocol_versions() {
        let header = (2u32 << 28) | 10;
        assert!(matches!(
            parse_header(header),
            Err(OlaError::UnsupportedProtocolVersion(2))
        ));
    }

    #[test]
    fn rejects_oversized_dmx() {
        let data = vec![0u8; 513];
        assert!(matches!(validate_dmx(&data), Err(OlaError::DmxFrameTooLong(513))));
    }

    /// The 512-slot boundary and empty frames are both valid.
    #[test]
    fn accepts_full_and_empty_dmx_frames() {
        assert!(validate_dmx(&[0u8; 512]).is_ok());
        assert!(validate_dmx(&[]).is_ok());
    }

    #[test]
    fn dmx_data_encodes() {
        let data = DmxData {
            universe: 1,
            data: vec![1, 2, 3],
            priority: Some(100),
        };
        let encoded = data.encode_to_vec();
        let decoded = DmxData::decode(encoded.as_slice()).unwrap();
        assert_eq!(decoded.universe, 1);
        assert_eq!(decoded.data, vec![1, 2, 3]);
        assert_eq!(decoded.priority, Some(100));
    }

    /// The RPC envelope survives an encode/decode round trip unchanged.
    #[test]
    fn rpc_message_roundtrip() {
        let message = RpcMessage {
            r#type: RpcType::Request as i32,
            id: Some(7),
            name: Some("GetDmx".to_string()),
            buffer: Some(UniverseRequest { universe: 3 }.encode_to_vec()),
        };
        let decoded = RpcMessage::decode(message.encode_to_vec().as_slice()).unwrap();
        assert_eq!(decoded, message);
    }
}