1use prost::Message;
8use std::io::{Read, Write};
9use std::net::{TcpStream, ToSocketAddrs};
10use std::time::Duration;
11use thiserror::Error;
12
/// RPC framing version placed in the top four bits of every frame header.
const PROTOCOL_VERSION: u32 = 1;
/// Selects the top four bits of a frame header (the version field).
const VERSION_MASK: u32 = 0xF << 28;
/// Selects the low 28 bits of a frame header (the body size field).
const SIZE_MASK: u32 = !VERSION_MASK;
/// TCP port used by the default client configuration.
const DEFAULT_OLA_PORT: u16 = 9010;
17
18#[derive(Debug, Error)]
19pub enum OlaError {
20 #[error("I/O error: {0}")]
21 Io(#[from] std::io::Error),
22 #[error("protobuf encode error: {0}")]
23 Encode(#[from] prost::EncodeError),
24 #[error("protobuf decode error: {0}")]
25 Decode(#[from] prost::DecodeError),
26 #[error("unsupported OLA RPC protocol version {0}")]
27 UnsupportedProtocolVersion(u32),
28 #[error("OLA RPC failed: {0}")]
29 RpcFailed(String),
30 #[error("unexpected OLA RPC response type {0}")]
31 UnexpectedResponseType(i32),
32 #[error("response id mismatch: expected {expected}, got {actual}")]
33 ResponseIdMismatch { expected: u32, actual: u32 },
34 #[error("DMX frame length {0} exceeds 512 bytes")]
35 DmxFrameTooLong(usize),
36}
37
38pub type Result<T> = std::result::Result<T, OlaError>;
39
40#[derive(Clone, PartialEq, Message)]
41struct RpcMessage {
42 #[prost(enumeration = "RpcType", required, tag = "1")]
43 r#type: i32,
44 #[prost(uint32, optional, tag = "2")]
45 id: Option<u32>,
46 #[prost(string, optional, tag = "3")]
47 name: Option<String>,
48 #[prost(bytes, optional, tag = "4")]
49 buffer: Option<Vec<u8>>,
50}
51
52#[derive(Clone, Copy, Debug, PartialEq, Eq, prost::Enumeration)]
53#[repr(i32)]
54enum RpcType {
55 Request = 1,
56 Response = 2,
57 ResponseCancel = 3,
58 ResponseFailed = 4,
59 ResponseNotImplemented = 5,
60 Disconnect = 6,
61 DescriptorRequest = 7,
62 DescriptorResponse = 8,
63 RequestCancel = 9,
64 StreamRequest = 10,
65}
66
67#[derive(Clone, PartialEq, Message)]
68pub struct Ack {
69 #[prost(bool, required, tag = "1")]
70 pub success: bool,
71}
72
73#[derive(Clone, PartialEq, Message)]
74pub struct DmxData {
75 #[prost(int32, required, tag = "1")]
76 pub universe: i32,
77 #[prost(bytes, required, tag = "2")]
78 pub data: Vec<u8>,
79 #[prost(int32, optional, tag = "3")]
80 pub priority: Option<i32>,
81}
82
83#[derive(Clone, PartialEq, Message)]
84pub struct UniverseRequest {
85 #[prost(int32, required, tag = "1")]
86 pub universe: i32,
87}
88
/// Connection settings for an OLA client.
#[derive(Clone, Debug)]
pub struct OlaConfig {
    /// Host name or IP address of the server.
    pub host: String,
    /// TCP port of the server.
    pub port: u16,
    /// Upper bound on the time spent establishing the TCP connection.
    pub connect_timeout: Duration,
    /// Socket read timeout; `None` leaves reads without a timeout.
    pub read_timeout: Option<Duration>,
    /// Socket write timeout; `None` leaves writes without a timeout.
    pub write_timeout: Option<Duration>,
}
97
98impl Default for OlaConfig {
99 fn default() -> Self {
100 Self {
101 host: "127.0.0.1".to_string(),
102 port: DEFAULT_OLA_PORT,
103 connect_timeout: Duration::from_secs(2),
104 read_timeout: Some(Duration::from_secs(2)),
105 write_timeout: Some(Duration::from_secs(2)),
106 }
107 }
108}
109
/// Blocking client that speaks the OLA RPC protocol over one TCP connection.
pub struct OlaClient {
    /// Connected socket used for every request and reply.
    stream: TcpStream,
    /// Id handed to the most recent request; `0` until the first one is sent.
    next_id: u32,
}
114
115impl OlaClient {
116 pub fn connect(config: OlaConfig) -> Result<Self> {
117 let addr = (config.host.as_str(), config.port)
118 .to_socket_addrs()?
119 .next()
120 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "no address resolved"))?;
121 let stream = TcpStream::connect_timeout(&addr, config.connect_timeout)?;
122 stream.set_nodelay(true)?;
123 stream.set_read_timeout(config.read_timeout)?;
124 stream.set_write_timeout(config.write_timeout)?;
125 Ok(Self { stream, next_id: 0 })
126 }
127
128 pub fn connect_default() -> Result<Self> {
129 Self::connect(OlaConfig::default())
130 }
131
132 pub fn update_dmx(&mut self, universe: i32, data: &[u8], priority: Option<i32>) -> Result<Ack> {
133 validate_dmx(data)?;
134 let request = DmxData {
135 universe,
136 data: data.to_vec(),
137 priority,
138 };
139 self.request("UpdateDmxData", &request)
140 }
141
142 pub fn stream_dmx(&mut self, universe: i32, data: &[u8], priority: Option<i32>) -> Result<()> {
143 validate_dmx(data)?;
144 let request = DmxData {
145 universe,
146 data: data.to_vec(),
147 priority,
148 };
149 self.stream_request("StreamDmxData", &request)
150 }
151
152 pub fn get_dmx(&mut self, universe: i32) -> Result<DmxData> {
153 let request = UniverseRequest { universe };
154 self.request("GetDmx", &request)
155 }
156
157 pub fn blackout(&mut self, universe: i32) -> Result<Ack> {
158 self.update_dmx(universe, &[0; 512], None)
159 }
160
161 pub fn stream_blackout(&mut self, universe: i32) -> Result<()> {
162 self.stream_dmx(universe, &[0; 512], None)
163 }
164
165 fn request<M, R>(&mut self, name: &str, message: &M) -> Result<R>
166 where
167 M: Message,
168 R: Message + Default,
169 {
170 let id = self.next_request_id();
171 let wrapper = RpcMessage {
172 r#type: RpcType::Request as i32,
173 id: Some(id),
174 name: Some(name.to_string()),
175 buffer: Some(message.encode_to_vec()),
176 };
177 self.write_wrapper(&wrapper)?;
178 let response = self.read_wrapper()?;
179 self.decode_response(id, response)
180 }
181
182 fn stream_request<M>(&mut self, name: &str, message: &M) -> Result<()>
183 where
184 M: Message,
185 {
186 let id = self.next_request_id();
187 let wrapper = RpcMessage {
188 r#type: RpcType::StreamRequest as i32,
189 id: Some(id),
190 name: Some(name.to_string()),
191 buffer: Some(message.encode_to_vec()),
192 };
193 self.write_wrapper(&wrapper)
194 }
195
196 fn next_request_id(&mut self) -> u32 {
197 self.next_id = if self.next_id == i32::MAX as u32 { 1 } else { self.next_id + 1 };
198 self.next_id
199 }
200
201 fn write_wrapper(&mut self, wrapper: &RpcMessage) -> Result<()> {
202 let body = wrapper.encode_to_vec();
203 let header = build_header(body.len())?.to_ne_bytes();
204 self.stream.write_all(&header)?;
205 self.stream.write_all(&body)?;
206 self.stream.flush()?;
207 Ok(())
208 }
209
210 fn read_wrapper(&mut self) -> Result<RpcMessage> {
211 let mut header = [0u8; 4];
212 self.stream.read_exact(&mut header)?;
213 let len = parse_header(u32::from_ne_bytes(header))?;
214 let mut body = vec![0u8; len];
215 self.stream.read_exact(&mut body)?;
216 Ok(RpcMessage::decode(body.as_slice())?)
217 }
218
219 fn decode_response<R>(&self, expected_id: u32, response: RpcMessage) -> Result<R>
220 where
221 R: Message + Default,
222 {
223 let actual_id = response.id.unwrap_or_default();
224 if actual_id != expected_id {
225 return Err(OlaError::ResponseIdMismatch {
226 expected: expected_id,
227 actual: actual_id,
228 });
229 }
230
231 match response.r#type {
232 x if x == RpcType::Response as i32 => {
233 let buffer = response.buffer.unwrap_or_default();
234 Ok(R::decode(buffer.as_slice())?)
235 }
236 x if x == RpcType::ResponseFailed as i32 => {
237 let buffer = response.buffer.unwrap_or_default();
238 let message = String::from_utf8_lossy(&buffer).to_string();
239 Err(OlaError::RpcFailed(message))
240 }
241 other => Err(OlaError::UnexpectedResponseType(other)),
242 }
243 }
244}
245
246fn validate_dmx(data: &[u8]) -> Result<()> {
247 if data.len() > 512 {
248 return Err(OlaError::DmxFrameTooLong(data.len()));
249 }
250 Ok(())
251}
252
253fn build_header(length: usize) -> Result<u32> {
254 let length = u32::try_from(length).map_err(|_| OlaError::DmxFrameTooLong(length))?;
255 Ok(((PROTOCOL_VERSION << 28) & VERSION_MASK) | (length & SIZE_MASK))
256}
257
258fn parse_header(header: u32) -> Result<usize> {
259 let version = (header & VERSION_MASK) >> 28;
260 if version != PROTOCOL_VERSION {
261 return Err(OlaError::UnsupportedProtocolVersion(version));
262 }
263 Ok((header & SIZE_MASK) as usize)
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// A length packed into a header is recovered unchanged when parsed.
    #[test]
    fn header_roundtrip() {
        let packed = build_header(1234).unwrap();
        let recovered = parse_header(packed).unwrap();
        assert_eq!(recovered, 1234);
    }

    /// One byte past the 512-byte limit is rejected with the actual length.
    #[test]
    fn rejects_oversized_dmx() {
        let oversized = vec![0u8; 513];
        let outcome = validate_dmx(&oversized);
        assert!(matches!(outcome, Err(OlaError::DmxFrameTooLong(513))));
    }

    /// Every field of a `DmxData` message survives an encode/decode cycle.
    #[test]
    fn dmx_data_encodes() {
        let bytes = DmxData {
            universe: 1,
            data: vec![1, 2, 3],
            priority: Some(100),
        }
        .encode_to_vec();

        let DmxData {
            universe,
            data,
            priority,
        } = DmxData::decode(bytes.as_slice()).unwrap();

        assert_eq!(universe, 1);
        assert_eq!(data, vec![1, 2, 3]);
        assert_eq!(priority, Some(100));
    }
}