#![warn(missing_docs)]
use crate::error::errors::*;
use crate::packet::*;
use std::cell::RefCell;
use std::cmp;
use std::cmp::min;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::{Arc, Mutex, MutexGuard};
use std::thread;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
use socket2::{Domain, Socket, Type};
use uuid::Uuid;
/// Name given to the background thread that a [`SacnSource`] spawns for periodic updates.
const SND_UPDATE_THREAD_NAME: &str = "rust_sacn_snd_update_thread";
/// Start code carried by the stream-terminated packets sent when a [`SacnSource`] is dropped.
const DEFAULT_TERMINATE_START_CODE: u8 = 0;
/// How long the background update thread waits between calls to `perform_periodic_update`.
const DEFAULT_POLL_PERIOD: Duration = Duration::from_secs(1);
/// A sACN (E1.31) source that sends DMX data, synchronisation and universe discovery packets.
///
/// The source owns its internal state (shared with a background update thread) and the
/// handle to that thread. Dropping the source stops the thread and sends stream-terminated
/// packets for every registered universe.
#[derive(Debug)]
pub struct SacnSource {
    /// State shared with the background update thread.
    internal: Arc<Mutex<SacnSourceInternal>>,
    /// Handle to the background update thread; taken and joined when the source is dropped.
    update_thread: Option<JoinHandle<()>>,
}
/// Mutable state of a [`SacnSource`], kept behind a mutex so the update thread can share it.
#[derive(Debug)]
struct SacnSourceInternal {
    /// UDP socket used for every packet this source sends.
    socket: Socket,
    /// Address the socket was bound to; its IP version selects IPv4 vs IPv6 multicast targets.
    addr: SocketAddr,
    /// Component identifier written into the root layer of every packet.
    cid: Uuid,
    /// Source name written into data and universe discovery packets.
    name: String,
    /// Value of the preview-data flag in outgoing data packets.
    preview_data: bool,
    /// Next data-packet sequence number per universe. A `RefCell` so `&self` send methods can update it.
    data_sequences: RefCell<HashMap<u16, u8>>,
    /// Next synchronisation-packet sequence number per synchronisation universe.
    sync_sequences: RefCell<HashMap<u16, u8>>,
    /// Registered universes, kept in ascending order by `register_universe`.
    universes: Vec<u16>,
    /// Cleared on termination; `send` refuses to send once this is false.
    running: bool,
    /// When the last set of universe discovery packets was sent.
    last_discovery_advert_timestamp: Instant,
    /// Whether the periodic update sends universe discovery packets.
    is_sending_discovery: bool,
}
impl SacnSource {
    /// Creates an IPv4 source with a freshly generated random CID, bound to
    /// `0.0.0.0:ACN_SDT_MULTICAST_PORT`.
    ///
    /// # Errors
    /// Same as [`SacnSource::with_cid_ip`].
    pub fn new_v4(name: &str) -> Result<SacnSource> {
        let cid = Uuid::new_v4();
        SacnSource::with_cid_v4(name, cid)
    }

    /// Creates an IPv4 source with the given CID, bound to `0.0.0.0:ACN_SDT_MULTICAST_PORT`.
    ///
    /// # Errors
    /// Same as [`SacnSource::with_cid_ip`].
    pub fn with_cid_v4(name: &str, cid: Uuid) -> Result<SacnSource> {
        let ip = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), ACN_SDT_MULTICAST_PORT);
        SacnSource::with_cid_ip(name, cid, ip)
    }

    /// Creates an IPv6 source with a freshly generated random CID, bound to
    /// `[::]:ACN_SDT_MULTICAST_PORT`.
    ///
    /// # Errors
    /// Same as [`SacnSource::with_cid_ip`].
    pub fn new_v6(name: &str) -> Result<SacnSource> {
        let cid = Uuid::new_v4();
        SacnSource::with_cid_v6(name, cid)
    }

    /// Creates an IPv6 source with the given CID, bound to `[::]:ACN_SDT_MULTICAST_PORT`.
    ///
    /// # Errors
    /// Same as [`SacnSource::with_cid_ip`].
    pub fn with_cid_v6(name: &str, cid: Uuid) -> Result<SacnSource> {
        let ip = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), ACN_SDT_MULTICAST_PORT);
        SacnSource::with_cid_ip(name, cid, ip)
    }

    /// Creates a source bound to `ip` with a freshly generated random CID.
    ///
    /// # Errors
    /// Same as [`SacnSource::with_cid_ip`].
    pub fn with_ip(name: &str, ip: SocketAddr) -> Result<SacnSource> {
        SacnSource::with_cid_ip(name, Uuid::new_v4(), ip)
    }

    /// Creates a source with the given name and CID, bound to `ip`.
    ///
    /// The IP version of `ip` decides whether IPv4 or IPv6 multicast addresses are used.
    /// A background thread named `SND_UPDATE_THREAD_NAME` is started. Roughly every
    /// `DEFAULT_POLL_PERIOD`, it sends universe discovery packets when they are due.
    /// The thread exits when the source stops running, for example when it is dropped.
    ///
    /// # Errors
    /// * [`SacnError::MalformedSourceName`] if `name` is longer than
    ///   `E131_SOURCE_NAME_FIELD_LENGTH` bytes.
    /// * An error if the socket cannot be created, configured or bound, or if the update
    ///   thread cannot be spawned.
    pub fn with_cid_ip(name: &str, cid: Uuid, ip: SocketAddr) -> Result<SacnSource> {
        if name.len() > E131_SOURCE_NAME_FIELD_LENGTH {
            return Err(SacnError::MalformedSourceName(
                "Source name provided is longer than maximum allowed".to_string(),
            ));
        }
        let internal = Arc::new(Mutex::new(SacnSourceInternal::with_cid_ip(name, cid, ip)?));
        let mut trd_src = Arc::clone(&internal);
        let update_thread = thread::Builder::new()
            .name(SND_UPDATE_THREAD_NAME.into())
            .spawn(move || loop {
                // `park_timeout` instead of `sleep`. `Drop` can then `unpark` the thread, so
                // shutdown does not wait out the rest of the poll period. Spurious wake-ups are
                // harmless because `perform_periodic_update` rate-limits discovery itself.
                thread::park_timeout(DEFAULT_POLL_PERIOD);
                // Re-check after waking so no update runs once the source has stopped. A
                // poisoned mutex also ends the thread instead of panicking inside it.
                let keep_running = matches!(unlock_internal(&trd_src), Ok(i) if i.running);
                if !keep_running {
                    break;
                }
                if let Err(e) = perform_periodic_update(&mut trd_src) {
                    eprintln!("Periodic error: {e:?}");
                }
            })?;
        Ok(SacnSource {
            internal,
            update_thread: Some(update_thread),
        })
    }

    /// Registers every universe in `universes` so data can be sent on them.
    ///
    /// Already-registered universes are ignored.
    ///
    /// # Errors
    /// Fails on the first universe that is out of the allowed range. Universes earlier in the
    /// slice remain registered.
    pub fn register_universes(&mut self, universes: &[u16]) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.register_universes(universes)
    }

    /// Registers a single universe so data can be sent on it. Re-registering is a no-op.
    ///
    /// # Errors
    /// Fails if `universe` is out of the allowed range or the internal mutex is poisoned.
    pub fn register_universe(&mut self, universe: u16) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.register_universe(universe)
    }

    /// Sends `data` across one or more universes.
    ///
    /// `data` is split into consecutive chunks of `UNIVERSE_CHANNEL_CAPACITY` bytes, and chunk
    /// `i` is sent to `universes[i]`. Universes beyond those needed for `data` receive nothing.
    ///
    /// * `priority`: packet priority, `E131_DEFAULT_PRIORITY` when `None`.
    /// * `dst_ip`: unicast destination. When `None`, each chunk goes to its universe's
    ///   multicast address.
    /// * `synchronisation_addr`: synchronisation universe written into the packets. It must be
    ///   registered. Defaults to `NO_SYNC_UNIVERSE`.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// * the source has been terminated;
    /// * `data` is empty;
    /// * any entry of `universes`, or the synchronisation universe, is not registered;
    /// * too few universes are given for `data`;
    /// * the priority exceeds `E131_MAX_PRIORITY`;
    /// * a socket send fails.
    pub fn send(
        &mut self,
        universes: &[u16],
        data: &[u8],
        priority: Option<u8>,
        dst_ip: Option<SocketAddr>,
        synchronisation_addr: Option<u16>,
    ) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.send(
            universes,
            data,
            priority,
            dst_ip,
            synchronisation_addr,
        )
    }

    /// Sends a synchronisation packet for `universe`.
    ///
    /// The packet goes to `dst_ip` when given, otherwise to the universe's multicast address.
    ///
    /// # Errors
    /// Fails if `universe` is not registered or the socket send fails.
    pub fn send_sync_packet(&mut self, universe: u16, dst_ip: Option<SocketAddr>) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.send_sync_packet(universe, dst_ip)
    }

    /// Terminates the stream on `universe`, then deregisters it.
    ///
    /// Sends `E131_TERMINATE_STREAM_PACKET_COUNT` stream-terminated packets carrying
    /// `start_code` before deregistering.
    ///
    /// # Errors
    /// Fails if `universe` is not registered or a socket send fails.
    pub fn terminate_stream(&mut self, universe: u16, start_code: u8) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.terminate_stream(universe, start_code)
    }

    /// Returns the component identifier (CID) of this source.
    pub fn cid(&self) -> Result<Uuid> {
        Ok(*unlock_internal(&self.internal)?.cid())
    }

    /// Replaces the component identifier (CID) used in subsequent packets.
    pub fn set_cid(&mut self, cid: Uuid) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.set_cid(cid);
        Ok(())
    }

    /// Returns the source name.
    pub fn name(&self) -> Result<String> {
        Ok(unlock_internal(&self.internal)?.name().into())
    }

    /// Replaces the source name used in subsequent packets.
    ///
    /// # Errors
    /// [`SacnError::MalformedSourceName`] if `name` is longer than
    /// `E131_SOURCE_NAME_FIELD_LENGTH` bytes.
    pub fn set_name(&mut self, name: &str) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.set_name(name)
    }

    /// Returns whether outgoing data packets have the preview-data flag set.
    pub fn preview_mode(&self) -> Result<bool> {
        Ok(unlock_internal(&self.internal)?.preview_mode())
    }

    /// Sets the preview-data flag used in subsequent data packets.
    pub fn set_preview_mode(&mut self, preview_mode: bool) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.set_preview_mode(preview_mode);
        Ok(())
    }

    /// Enables or disables the periodic universe discovery packets sent by the update thread.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn set_is_sending_discovery(&mut self, val: bool) {
        self.internal
            .lock()
            .expect("SacnSource internal mutex poisoned")
            .set_is_sending_discovery(val);
    }

    /// Returns the IPv4 multicast TTL of the socket.
    pub fn multicast_ttl(&self) -> Result<u32> {
        unlock_internal(&self.internal)?.multicast_ttl()
    }

    /// Sets the IPv4 multicast TTL of the socket.
    pub fn set_multicast_ttl(&mut self, multicast_ttl: u32) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.set_multicast_ttl(multicast_ttl)
    }

    /// Returns the IPv4 unicast TTL of the socket.
    pub fn ttl(&self) -> Result<u32> {
        unlock_internal(&self.internal)?.ttl()
    }

    /// Sets the IPv4 unicast TTL of the socket.
    pub fn set_ttl(&mut self, ttl: u32) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.set_ttl(ttl)
    }

    /// Enables or disables IPv4 multicast loopback on the socket.
    pub fn set_multicast_loop_v4(&mut self, multicast_loop: bool) -> Result<()> {
        unlock_internal_mut(&mut self.internal)?.set_multicast_loop_v4(multicast_loop)
    }

    /// Returns whether IPv4 multicast loopback is enabled on the socket.
    pub fn multicast_loop(&self) -> Result<bool> {
        unlock_internal(&self.internal)?.multicast_loop()
    }

    /// Returns a copy of the registered universes, in ascending order.
    pub fn universes(&self) -> Result<Vec<u16>> {
        Ok(unlock_internal(&self.internal)?.universes())
    }
}
impl Drop for SacnSource {
    /// Stops the source and waits for its update thread to finish.
    ///
    /// Marks the source as no longer running and sends stream-terminated packets for every
    /// registered universe. Termination is best effort and its errors are ignored. The update
    /// thread is then woken and joined.
    ///
    /// If the mutex is poisoned, termination is skipped but the thread is still joined. A
    /// panic in the update thread is not re-raised here, because panicking inside `drop`
    /// while already unwinding would abort the process.
    fn drop(&mut self) {
        if let Ok(mut internal) = unlock_internal_mut(&mut self.internal) {
            internal.running = false;
            // There is no way to report an error from `drop`, so termination is best effort.
            let _ = internal.terminate(DEFAULT_TERMINATE_START_CODE);
        }
        // The lock guard above is released before joining, so the thread can observe `running`.
        if let Some(handle) = self.update_thread.take() {
            // Wake the thread if it is parked between polls, so it notices the shutdown promptly.
            handle.thread().unpark();
            let _ = handle.join();
        }
    }
}
impl SacnSourceInternal {
    /// Creates the state for a source named `name` with CID `cid`, plus a UDP socket bound to `ip`.
    ///
    /// The socket's address family follows `ip`. Address reuse is enabled, and port reuse too
    /// on Linux. The source starts running with discovery enabled, preview mode off and no
    /// registered universes. `name` is not length-checked here.
    ///
    /// # Errors
    /// Returns an error if the socket cannot be created, configured or bound.
    fn with_cid_ip(name: &str, cid: Uuid, ip: SocketAddr) -> Result<SacnSourceInternal> {
        // A `SocketAddr` is always either V4 or V6, so this choice is exhaustive.
        let domain = match ip {
            SocketAddr::V4(_) => Domain::IPV4,
            SocketAddr::V6(_) => Domain::IPV6,
        };
        // Propagate creation failures (e.g. IPv6 unavailable) instead of panicking.
        let socket = Socket::new(domain, Type::DGRAM, None)?;
        #[cfg(target_os = "linux")]
        socket.set_reuse_port(true)?;
        socket.set_reuse_address(true)?;
        socket.bind(&ip.into())?;
        Ok(SacnSourceInternal {
            socket,
            addr: ip,
            cid,
            name: name.to_string(),
            preview_data: false,
            data_sequences: RefCell::new(HashMap::new()),
            sync_sequences: RefCell::new(HashMap::new()),
            universes: Vec::new(),
            running: true,
            last_discovery_advert_timestamp: Instant::now(),
            is_sending_discovery: true,
        })
    }

    /// Enables or disables periodic universe discovery.
    fn set_is_sending_discovery(&mut self, val: bool) {
        self.is_sending_discovery = val;
    }

    /// Registers each universe in turn, stopping at the first error.
    fn register_universes(&mut self, universes: &[u16]) -> Result<()> {
        for &u in universes {
            self.register_universe(u)?;
        }
        Ok(())
    }

    /// Registers `universe`, keeping `self.universes` sorted and free of duplicates.
    ///
    /// # Errors
    /// Fails if `universe` is out of the allowed range.
    fn register_universe(&mut self, universe: u16) -> Result<()> {
        is_universe_in_range(universe)?;
        // `binary_search` on an empty Vec yields `Err(0)`, so no special case is needed.
        if let Err(pos) = self.universes.binary_search(&universe) {
            self.universes.insert(pos, universe);
        }
        Ok(())
    }

    /// Removes `universe` from the registered list.
    ///
    /// # Errors
    /// Fails if `universe` is out of range, or with [`SacnError::UniverseNotFound`] if it is
    /// not registered.
    fn deregister_universe(&mut self, universe: u16) -> Result<()> {
        is_universe_in_range(universe)?;
        let pos = self
            .universes
            .binary_search(&universe)
            .map_err(|_| SacnError::UniverseNotFound(universe))?;
        self.universes.remove(pos);
        Ok(())
    }

    /// Checks that `u` is in range and registered.
    ///
    /// # Errors
    /// Returns [`SacnError::UniverseNotRegistered`] if `u` has not been registered.
    fn universe_allowed(&self, u: &u16) -> Result<()> {
        is_universe_in_range(*u)?;
        // `universes` is kept sorted by `register_universe`, so a binary search suffices.
        if self.universes.binary_search(u).is_err() {
            return Err(SacnError::UniverseNotRegistered(*u));
        }
        Ok(())
    }

    /// Splits `data` into `UNIVERSE_CHANNEL_CAPACITY`-byte chunks and sends chunk `i` on
    /// `universes[i]`. See `SacnSource::send` for the parameter semantics and errors.
    fn send(
        &self,
        universes: &[u16],
        data: &[u8],
        priority: Option<u8>,
        dst_ip: Option<SocketAddr>,
        synchronisation_addr: Option<u16>,
    ) -> Result<()> {
        if !self.running {
            return Err(SacnError::SenderAlreadyTerminated(
                "Attempted to send".to_string(),
            ));
        }
        if data.is_empty() {
            return Err(SacnError::DataArrayEmpty());
        }
        for u in universes {
            self.universe_allowed(u)?;
        }
        if let Some(sync) = synchronisation_addr {
            self.universe_allowed(&sync)
                .map_err(|_| SacnError::IllegalSyncUniverse(sync))?;
        }
        let chunks = data.chunks(UNIVERSE_CHANNEL_CAPACITY);
        if universes.len() < chunks.len() {
            return Err(SacnError::UniverseListEmpty());
        }
        let priority = priority.unwrap_or(E131_DEFAULT_PRIORITY);
        let sync_address = synchronisation_addr.unwrap_or(NO_SYNC_UNIVERSE);
        for (&universe, chunk) in universes.iter().zip(chunks) {
            self.send_universe(universe, chunk, priority, &dst_ip, sync_address)?;
        }
        Ok(())
    }

    /// Sends one data packet for `universe` and advances that universe's sequence number.
    ///
    /// # Errors
    /// Fails if `priority > E131_MAX_PRIORITY`, if `data` exceeds `UNIVERSE_CHANNEL_CAPACITY`,
    /// or if packing or sending fails. The sequence number only advances after a successful send.
    fn send_universe(
        &self,
        universe: u16,
        data: &[u8],
        priority: u8,
        dst_ip: &Option<SocketAddr>,
        sync_address: u16,
    ) -> Result<()> {
        if priority > E131_MAX_PRIORITY {
            return Err(SacnError::InvalidPriority(priority));
        }
        if data.len() > UNIVERSE_CHANNEL_CAPACITY {
            return Err(SacnError::ExceedUniverseCapacity(data.len()));
        }
        let sequence = self
            .data_sequences
            .borrow()
            .get(&universe)
            .copied()
            .unwrap_or(STARTING_SEQUENCE_NUMBER);
        let packet = AcnRootLayerProtocol {
            pdu: E131RootLayer {
                cid: self.cid,
                data: E131RootLayerData::DataPacket(DataPacketFramingLayer {
                    source_name: self.name.as_str().into(),
                    priority,
                    synchronization_address: sync_address,
                    sequence_number: sequence,
                    preview_data: self.preview_data,
                    stream_terminated: false,
                    force_synchronization: false,
                    universe,
                    data: DataPacketDmpLayer {
                        property_values: data.to_vec().into(),
                    },
                }),
            },
        };
        // Report packing failures as errors instead of panicking.
        let bytes = packet.pack_alloc()?;
        if let Some(dst) = dst_ip {
            self.socket
                .send_to(&bytes, &(*dst).into())
                .map_err(|e| {
                    std::io::Error::new(e.kind(), "Failed to send data unicast on socket")
                })?;
        } else {
            let dst = if self.addr.is_ipv6() {
                universe_to_ipv6_multicast_addr(universe)?
            } else {
                universe_to_ipv4_multicast_addr(universe)?
            };
            self.socket
                .send_to(&bytes, &dst)
                .map_err(|e| {
                    std::io::Error::new(e.kind(), "Failed to send data multicast on socket")
                })?;
        }
        // Sequence numbers are a u8 that wraps from 255 back to 0.
        self.data_sequences
            .borrow_mut()
            .insert(universe, sequence.wrapping_add(1));
        Ok(())
    }

    /// Sends a synchronisation packet for `universe`, to `dst_ip` or the universe's multicast address.
    ///
    /// # Errors
    /// Fails if `universe` is not registered or packing or sending fails.
    fn send_sync_packet(&self, universe: u16, dst_ip: Option<SocketAddr>) -> Result<()> {
        self.universe_allowed(&universe)?;
        let ip = if let Some(dst) = dst_ip {
            dst.into()
        } else if self.addr.is_ipv6() {
            universe_to_ipv6_multicast_addr(universe)?
        } else {
            universe_to_ipv4_multicast_addr(universe)?
        };
        let sequence = self
            .sync_sequences
            .borrow()
            .get(&universe)
            .copied()
            .unwrap_or(STARTING_SEQUENCE_NUMBER);
        let packet = AcnRootLayerProtocol {
            pdu: E131RootLayer {
                cid: self.cid,
                data: E131RootLayerData::SynchronizationPacket(SynchronizationPacketFramingLayer {
                    sequence_number: sequence,
                    synchronization_address: universe,
                }),
            },
        };
        self.socket
            .send_to(&packet.pack_alloc()?, &ip)
            .map_err(|e| std::io::Error::new(e.kind(), "Failed to send sync packet on socket"))?;
        self.sync_sequences
            .borrow_mut()
            .insert(universe, sequence.wrapping_add(1));
        Ok(())
    }

    /// Sends a single stream-terminated data packet for `universe`.
    ///
    /// The packet carries just `start_code` as its data. It uses the data sequence numbers
    /// shared with `send_universe`.
    ///
    /// # Errors
    /// Fails if `universe` is not registered or packing or sending fails.
    fn send_terminate_stream_pkt(
        &self,
        universe: u16,
        dst_ip: Option<SocketAddr>,
        start_code: u8,
    ) -> Result<()> {
        self.universe_allowed(&universe)?;
        let ip = match dst_ip {
            Some(x) => x.into(),
            None => {
                if self.addr.is_ipv6() {
                    universe_to_ipv6_multicast_addr(universe)?
                } else {
                    universe_to_ipv4_multicast_addr(universe)?
                }
            }
        };
        // Read (not remove) the sequence so a failed send does not lose it.
        let sequence = self
            .data_sequences
            .borrow()
            .get(&universe)
            .copied()
            .unwrap_or(STARTING_SEQUENCE_NUMBER);
        let packet = AcnRootLayerProtocol {
            pdu: E131RootLayer {
                cid: self.cid,
                data: E131RootLayerData::DataPacket(DataPacketFramingLayer {
                    source_name: self.name.as_str().into(),
                    priority: 100,
                    synchronization_address: 0,
                    sequence_number: sequence,
                    preview_data: self.preview_data,
                    stream_terminated: true,
                    force_synchronization: false,
                    universe,
                    data: DataPacketDmpLayer {
                        property_values: vec![start_code].into(),
                    },
                }),
            },
        };
        let bytes = packet.pack_alloc()?;
        self.socket.send_to(&bytes, &ip)?;
        self.data_sequences
            .borrow_mut()
            .insert(universe, sequence.wrapping_add(1));
        Ok(())
    }

    /// Sends `E131_TERMINATE_STREAM_PACKET_COUNT` stream-terminated packets for `universe`
    /// to its multicast address, then deregisters it.
    fn terminate_stream(&mut self, universe: u16, start_code: u8) -> Result<()> {
        for _ in 0..E131_TERMINATE_STREAM_PACKET_COUNT {
            self.send_terminate_stream_pkt(universe, None, start_code)?;
        }
        self.deregister_universe(universe)?;
        Ok(())
    }

    /// Marks the source as not running and terminates the stream on every registered universe.
    ///
    /// Every universe is attempted even if an earlier one fails, so a single socket error
    /// does not leave other receivers without a stream-terminated notification.
    ///
    /// # Errors
    /// Returns the first error encountered, if any.
    fn terminate(&mut self, start_code: u8) -> Result<()> {
        self.running = false;
        let mut first_err = None;
        // Iterate a snapshot: `terminate_stream` removes entries from `self.universes`.
        for u in self.universes.clone() {
            if let Err(e) = self.terminate_stream(u, start_code) {
                if first_err.is_none() {
                    first_err = Some(e);
                }
            }
        }
        match first_err {
            Some(e) => Err(e),
            None => Ok(()),
        }
    }

    /// Advertises all registered universes in pages of up to `DISCOVERY_UNI_PER_PAGE` entries.
    fn send_universe_discovery(&self) -> Result<()> {
        let universe_count = self.universes.len();
        // One page per started group of DISCOVERY_UNI_PER_PAGE universes, but at least one so
        // an empty list is still advertised. (`len / N + 1` would emit a trailing empty page
        // whenever `len` is an exact multiple of N.)
        let started_pages = universe_count / DISCOVERY_UNI_PER_PAGE
            + usize::from(universe_count % DISCOVERY_UNI_PER_PAGE != 0);
        let pages_req = cmp::max(1, started_pages);
        let last_page = (pages_req - 1) as u8;
        for p in 0..pages_req {
            let start_index = p * DISCOVERY_UNI_PER_PAGE;
            let end_index = min(start_index + DISCOVERY_UNI_PER_PAGE, universe_count);
            self.send_universe_discovery_detailed(
                p as u8,
                last_page,
                &self.universes[start_index..end_index],
            )?;
        }
        Ok(())
    }

    /// Sends one universe discovery page listing `universes` to the discovery universe's multicast address.
    fn send_universe_discovery_detailed(
        &self,
        page: u8,
        last_page: u8,
        universes: &[u16],
    ) -> Result<()> {
        let packet = AcnRootLayerProtocol {
            pdu: E131RootLayer {
                cid: self.cid,
                data: E131RootLayerData::UniverseDiscoveryPacket(
                    UniverseDiscoveryPacketFramingLayer {
                        source_name: self.name.as_str().into(),
                        data: UniverseDiscoveryPacketUniverseDiscoveryLayer {
                            page,
                            last_page,
                            universes: universes.into(),
                        },
                    },
                ),
            },
        };
        let ip = if self.addr.is_ipv6() {
            universe_to_ipv6_multicast_addr(E131_DISCOVERY_UNIVERSE)?
        } else {
            universe_to_ipv4_multicast_addr(E131_DISCOVERY_UNIVERSE)?
        };
        self.socket.send_to(&packet.pack_alloc()?, &ip)?;
        Ok(())
    }

    /// Returns the component identifier.
    fn cid(&self) -> &Uuid {
        &self.cid
    }

    /// Replaces the component identifier.
    fn set_cid(&mut self, cid: Uuid) {
        self.cid = cid;
    }

    /// Returns the source name.
    fn name(&self) -> &str {
        &self.name
    }

    /// Replaces the source name.
    ///
    /// # Errors
    /// Fails with [`SacnError::MalformedSourceName`] if it is longer than
    /// `E131_SOURCE_NAME_FIELD_LENGTH` bytes.
    fn set_name(&mut self, name: &str) -> Result<()> {
        if name.len() > E131_SOURCE_NAME_FIELD_LENGTH {
            return Err(SacnError::MalformedSourceName(
                "Source name provided is longer than maximum allowed".to_string(),
            ));
        }
        self.name = name.to_string();
        Ok(())
    }

    /// Returns the preview-data flag.
    fn preview_mode(&self) -> bool {
        self.preview_data
    }

    /// Sets the preview-data flag.
    fn set_preview_mode(&mut self, preview_mode: bool) {
        self.preview_data = preview_mode;
    }

    /// Sets the IPv4 multicast TTL socket option.
    fn set_multicast_ttl(&self, multicast_ttl: u32) -> Result<()> {
        Ok(self.socket.set_multicast_ttl_v4(multicast_ttl)?)
    }

    /// Returns the IPv4 unicast TTL socket option.
    fn ttl(&self) -> Result<u32> {
        Ok(self.socket.ttl_v4()?)
    }

    /// Sets the IPv4 unicast TTL socket option.
    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
        Ok(self.socket.set_ttl_v4(ttl)?)
    }

    /// Returns the IPv4 multicast TTL socket option.
    fn multicast_ttl(&self) -> Result<u32> {
        Ok(self.socket.multicast_ttl_v4()?)
    }

    /// Sets the IPv4 multicast loopback socket option.
    fn set_multicast_loop_v4(&self, multicast_loop: bool) -> Result<()> {
        Ok(self.socket.set_multicast_loop_v4(multicast_loop)?)
    }

    /// Returns the IPv4 multicast loopback socket option.
    fn multicast_loop(&self) -> Result<bool> {
        Ok(self.socket.multicast_loop_v4()?)
    }

    /// Returns a copy of the registered universes (ascending order).
    pub fn universes(&self) -> Vec<u16> {
        self.universes.clone()
    }
}
/// Locks the shared source state for the `&self` accessors of `SacnSource`.
///
/// # Errors
/// Returns `SacnError::SourceCorrupt` if the mutex has been poisoned.
fn unlock_internal(
    internal: &Arc<Mutex<SacnSourceInternal>>,
) -> Result<MutexGuard<'_, SacnSourceInternal>> {
    internal
        .lock()
        .map_err(|_poisoned| SacnError::SourceCorrupt("Mutex poisoned".to_string()))
}
/// Locks the shared source state. This is the variant taking `&mut Arc`, used by the
/// `&mut self` methods of `SacnSource`, its `Drop` and the periodic update.
///
/// # Errors
/// Returns `SacnError::SourceCorrupt` if the mutex has been poisoned.
fn unlock_internal_mut(
    internal: &mut Arc<Mutex<SacnSourceInternal>>,
) -> Result<MutexGuard<'_, SacnSourceInternal>> {
    let locked = internal.lock();
    locked.map_err(|_poisoned| SacnError::SourceCorrupt("Mutex poisoned".to_string()))
}
/// Runs one tick of background work.
///
/// Universe discovery is sent only when discovery is enabled and more than
/// `E131_UNIVERSE_DISCOVERY_INTERVAL` has passed since the last advert.
///
/// # Errors
/// Fails if the state mutex is poisoned or sending discovery fails. If sending fails, the
/// advert timestamp is left unchanged, so the next tick tries again.
fn perform_periodic_update(src: &mut Arc<Mutex<SacnSourceInternal>>) -> Result<()> {
    let mut state = unlock_internal_mut(src)?;
    if !state.is_sending_discovery {
        return Ok(());
    }
    let since_last_advert = Instant::now().duration_since(state.last_discovery_advert_timestamp);
    if since_last_advert <= E131_UNIVERSE_DISCOVERY_INTERVAL {
        return Ok(());
    }
    state.send_universe_discovery()?;
    state.last_discovery_advert_timestamp = Instant::now();
    Ok(())
}