mod fetch;
mod mdns;
mod oscquery;
pub use oscquery::models;
pub use rosc;
pub use fetch::Error as FetchError;
pub use mdns::Error as MdnsError;
pub use oscquery::Error as OscQueryError;
use crate::fetch::fetch;
use crate::oscquery::OscQuery;
use futures::{stream, StreamExt};
use hickory_proto::rr::Name;
use oscquery::models::{HostInfo, OscNode, OscRootNode};
use rosc::OscPacket;
use std::str::FromStr;
use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
};
use tokio::{
net::UdpSocket,
sync::{mpsc, RwLock},
task::JoinHandle,
};
use wildmatch::WildMatch;
/// Size in bytes of the receive buffer for incoming OSC datagrams: 65535,
/// the largest value a 16-bit length can express.
const OSC_PACKET_BUFFER_SIZE: usize = u16::MAX as usize;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("OSC error: {0}")]
OscError(#[from] rosc::OscError),
#[error("OSCQuery error: {0}")]
OscQueryError(#[from] oscquery::Error),
#[error("mDNS error: {0}")]
MdnsError(#[from] mdns::Error),
#[error("Hickory DNS protocol error: {0}")]
HickoryError(#[from] hickory_proto::ProtoError),
#[error("I/O error: {0}")]
IoError(#[from] std::io::Error),
#[error("Fetch error: {0}")]
FetchError(#[from] fetch::Error),
#[error("No valid network interface found: no non-loopback IPv4 address available")]
NoValidInterface,
}
/// Per-service resources owned by a registered service. Both must be torn
/// down when the service goes away.
struct ServiceHandle {
    /// Task that runs the OSC receive loop for this service.
    osc: JoinHandle<()>,
    /// OSCQuery server that exposes this service's parameter tree.
    osc_query: OscQuery,
}
/// A service found through mDNS discovery and passed to the `on_connect`
/// callback.
///
/// Each variant carries the fully-qualified mDNS instance name and the
/// resolved socket address of the service.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ServiceType {
    /// An `_osc._udp.local.` service that accepts OSC packets.
    Osc(String, SocketAddr),
    /// An `_oscjson._tcp.local.` OSCQuery service.
    OscQuery(String, SocketAddr),
}
/// OSC/OSCQuery client. It discovers peers over mDNS, advertises its own
/// services, and sends and receives OSC packets.
///
/// Build it with `VRChatOSC::new`, which returns it wrapped in an `Arc`.
pub struct VRChatOSC {
    /// UDP socket used for every outgoing OSC packet.
    send_socket: UdpSocket,
    /// Address of the OSC peer. It determines which local IP is advertised over mDNS.
    osc_ip: Arc<RwLock<IpAddr>>,
    /// mDNS client used for discovery and for advertising registered services.
    mdns: mdns::Mdns,
    /// Registered services, keyed by the name passed to `register` (before sanitization).
    service_handles: Arc<RwLock<HashMap<String, ServiceHandle>>>,
    /// Optional callback run for each discovered service; set with `on_connect`.
    on_service_discovered_callback:
        Arc<RwLock<Option<Arc<dyn Fn(ServiceType) + Send + Sync + 'static>>>>,
}
/// Returns the first non-loopback IPv4 address in the order reported by
/// `if_addrs::get_if_addrs`.
///
/// Returns `None` when interface enumeration fails or when every IPv4 address
/// is a loopback address.
fn find_non_loopback_ipv4() -> Option<IpAddr> {
    let interfaces = if_addrs::get_if_addrs().ok()?;
    interfaces.into_iter().find_map(|iface| match iface.addr.ip() {
        IpAddr::V4(v4) if !v4.is_loopback() => Some(IpAddr::V4(v4)),
        _ => None,
    })
}
/// Picks the local IP address to advertise to a peer at `dest_ip`.
///
/// - A loopback `dest_ip` is returned unchanged.
/// - If `dest_ip` is one of this host's own interface addresses, it is returned.
/// - Otherwise the OS routing table is consulted. An unbound UDP socket is
///   "connected" to `dest_ip`, which sends no packets, and the local address
///   the kernel chose is read back.
///
/// # Errors
/// Returns `Error::IoError` when the probe socket cannot be bound or connected,
/// for example when there is no route to `dest_ip`.
fn find_local_ip_for_destination(dest_ip: IpAddr) -> Result<IpAddr, Error> {
    // UDP `connect` only selects a route, so any non-zero port works. Port 0 is
    // not used because some stacks (e.g. BSD-derived ones such as macOS) reject
    // a zero destination port with EADDRNOTAVAIL.
    const PROBE_PORT: u16 = 9;

    if dest_ip.is_loopback() {
        return Ok(dest_ip);
    }

    // If enumeration fails, fall through to the routing-table probe, as before.
    let is_own_address = if_addrs::get_if_addrs()
        .map(|interfaces| interfaces.iter().any(|iface| iface.addr.ip() == dest_ip))
        .unwrap_or(false);
    if is_own_address {
        return Ok(dest_ip);
    }

    let bind_addr = match dest_ip {
        IpAddr::V4(_) => "0.0.0.0:0",
        IpAddr::V6(_) => "[::]:0",
    };
    let socket = std::net::UdpSocket::bind(bind_addr)?;
    socket.connect((dest_ip, PROBE_PORT))?;
    Ok(socket.local_addr()?.ip())
}
/// Turns a user-supplied service name into a string usable as an mDNS
/// instance label.
///
/// - ASCII letters are lowercased and ASCII digits are kept.
/// - Every other ASCII character and every control character becomes `-`.
/// - Non-ASCII, non-control characters pass through unchanged.
fn sanitize_service_name(name: &str) -> String {
    let mut sanitized = String::with_capacity(name.len());
    for ch in name.chars() {
        let mapped = match ch {
            'A'..='Z' => ch.to_ascii_lowercase(),
            c if c.is_ascii_alphanumeric() => c,
            c if c.is_ascii() || c.is_control() => '-',
            c => c,
        };
        sanitized.push(mapped);
    }
    sanitized
}
impl VRChatOSC {
pub async fn new(osc_ip: Option<IpAddr>) -> Result<Arc<VRChatOSC>, Error> {
let osc_ip = match osc_ip {
Some(ip) => ip,
None => find_non_loopback_ipv4().ok_or(Error::NoValidInterface)?,
};
let socket = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).await?;
let (discover_notifier_tx, mut discover_notifier_rx) = mpsc::channel(8);
let advertised_ip = find_local_ip_for_destination(osc_ip)?;
let mdns_client = mdns::Mdns::new(discover_notifier_tx, advertised_ip).await?;
let _ = mdns_client
.follow(Name::from_ascii("_osc._udp.local.")?)
.await;
let _ = mdns_client
.follow(Name::from_ascii("_oscjson._tcp.local.")?)
.await;
let on_service_discovered_callback = Arc::new(RwLock::new(
None::<Arc<dyn Fn(ServiceType) + Send + Sync + 'static>>,
));
let callback_arc_clone = on_service_discovered_callback.clone();
tokio::spawn(async move {
loop {
if let Some((service_name, socket_addr)) = discover_notifier_rx.recv().await {
let callback_guard = callback_arc_clone.read().await;
if let Some(callback) = callback_guard.as_ref() {
if service_name.trim_to(3).to_utf8() == "_osc._udp.local." {
callback(ServiceType::Osc(service_name.to_utf8(), socket_addr));
} else if service_name.trim_to(3).to_utf8() == "_oscjson._tcp.local." {
callback(ServiceType::OscQuery(service_name.to_utf8(), socket_addr));
}
}
}
}
});
Ok(Arc::new(VRChatOSC {
send_socket: socket,
osc_ip: Arc::new(RwLock::new(osc_ip)),
mdns: mdns_client,
service_handles: Arc::new(RwLock::new(HashMap::new())),
on_service_discovered_callback,
}))
}
pub async fn set_osc_ip(&self, ip: Option<IpAddr>) -> Result<(), Error> {
let ip = match ip {
Some(ip) => ip,
None => find_non_loopback_ipv4().ok_or(Error::NoValidInterface)?,
};
*self.osc_ip.write().await = ip;
let advertised_ip = find_local_ip_for_destination(ip)?;
self.mdns.set_advertised_ip(advertised_ip).await;
Ok(())
}
pub async fn get_osc_ip(&self) -> IpAddr {
*self.osc_ip.read().await
}
pub async fn on_connect<F>(&self, callback: F)
where
F: Fn(ServiceType) + Send + Sync + 'static,
{
let mut callback_guard = self.on_service_discovered_callback.write().await;
*callback_guard = Some(Arc::new(callback));
}
pub async fn register<F>(
&self,
service_name: &str,
parameters: OscRootNode,
handler: F,
) -> Result<(), Error>
where
F: Fn(OscPacket) + Send + 'static,
{
let socket = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).await?;
let osc_local_addr = socket.local_addr()?;
let osc_handle = tokio::spawn(async move {
let mut buf = [0; OSC_PACKET_BUFFER_SIZE]; loop {
match socket.recv_from(&mut buf).await {
Ok((len, addr)) => {
if let Ok((_, packet)) = rosc::decoder::decode_udp(&buf[..len]) {
handler(packet); } else {
log::debug!("Failed to decode OSC packet from {}", addr);
}
}
Err(e) => {
if e.kind() == std::io::ErrorKind::ConnectionReset
|| e.kind() == std::io::ErrorKind::BrokenPipe
{
log::warn!("Socket connection error ({}). Task for {:?} might need to be restarted or interface is down.", e, socket.local_addr().ok());
break;
} else {
log::warn!(
"Failed to receive data on OSC socket {:?}: {}",
socket.local_addr().ok(),
e
);
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
continue;
}
}
}
});
let host_info = HostInfo::new(
service_name.to_string(),
osc_local_addr.ip(), osc_local_addr.port(), );
let mut osc_query = OscQuery::new(host_info, parameters);
let osc_query_local_addr = osc_query
.serve(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
.await?;
let sanitized_service_name = sanitize_service_name(service_name);
self.mdns
.register(
Name::from_str(&format!("{}._osc._udp.local.", sanitized_service_name))?,
osc_local_addr.port(),
)
.await?;
self.mdns
.register(
Name::from_str(&format!("{}._oscjson._tcp.local.", sanitized_service_name))?,
osc_query_local_addr.port(),
)
.await?;
let mut handles = self.service_handles.write().await;
handles.insert(
service_name.to_string(),
ServiceHandle {
osc: osc_handle,
osc_query,
},
);
Ok(())
}
pub async fn unregister(&self, service_name: &str) -> Result<(), Error> {
let sanitized_service_name = sanitize_service_name(service_name);
let mut service_handles_map = self.service_handles.write().await;
if let Some(mut service_handle_entry) = service_handles_map.remove(service_name) {
self.mdns
.unregister(Name::from_str(&format!(
"{}._osc._udp.local.",
sanitized_service_name
))?)
.await?;
self.mdns
.unregister(Name::from_str(&format!(
"{}._oscjson._tcp.local.",
sanitized_service_name
))?)
.await?;
service_handle_entry.osc.abort(); service_handle_entry.osc_query.shutdown(); }
Ok(())
}
pub async fn send(&self, packet: OscPacket, to: &str) -> Result<(), Error> {
let services = self
.mdns
.find_service(|name, _| {
WildMatch::new(&format!("{}._osc._udp.local.", to)).matches(&name.to_utf8())
})
.await;
if services.is_empty() {
log::info!("No mDNS services found matching the expression: {}", to);
return Ok(());
}
let msg_buf = rosc::encoder::encode(&packet)?;
let send_futs = services
.into_iter()
.map(|(_, addr)| self.send_socket.send_to(&msg_buf, addr));
let results = futures::future::join_all(send_futs).await;
for res in results {
res?;
}
Ok(())
}
pub async fn send_to_addr(&self, packet: OscPacket, addr: SocketAddr) -> Result<(), Error> {
let msg_buf = rosc::encoder::encode(&packet)?;
self.send_socket.send_to(&msg_buf, addr).await?;
Ok(())
}
pub async fn get_parameter(
&self,
method: &str,
from: &str,
) -> Result<Vec<(String, OscNode)>, Error> {
let services = self
.mdns
.find_service(|name, _| {
WildMatch::new(&format!("{}._oscjson._tcp.local.", from)).matches(&name.to_utf8())
})
.await;
if services.is_empty() {
log::info!(
"No mDNS services found for get_parameter matching expression: {}",
from
);
return Ok(Vec::new());
}
let params = stream::iter(services)
.map(|(name, addr)| async move {
fetch::<_, OscNode>(addr, method)
.await
.map(|(param, _)| (name.to_utf8(), param))
})
.buffer_unordered(3)
.filter_map(|res| async {
if let Err(e) = &res {
log::warn!("Failed to fetch parameter: {:?}", e);
}
res.ok()
})
.collect::<Vec<_>>()
.await;
Ok(params)
}
pub async fn get_parameter_from_addr(
&self,
method: &str,
addr: SocketAddr,
) -> Result<OscNode, Error> {
let (param, _url) = fetch::<_, OscNode>(addr, method).await?;
Ok(param)
}
pub async fn shutdown(&self) -> Result<(), Error> {
let mut service_handles_map = self.service_handles.write().await;
let service_names: Vec<String> = service_handles_map.keys().cloned().collect();
for name in service_names {
if let Some(mut handle) = service_handles_map.remove(&name) {
let sanitized_service_name = sanitize_service_name(&name);
if let Err(e) = self
.mdns
.unregister(Name::from_str(&format!(
"{}._osc._udp.local.",
sanitized_service_name
))?)
.await
{
log::error!("Failed to unregister OSC for {}: {}", name, e);
}
if let Err(e) = self
.mdns
.unregister(Name::from_str(&format!(
"{}._oscjson._tcp.local.",
sanitized_service_name
))?)
.await
{
log::error!("Failed to unregister OSCQuery for {}: {}", name, e);
}
handle.osc.abort();
handle.osc_query.shutdown();
}
}
Ok(())
}
pub async fn list_services(&self) -> Vec<String> {
let handles = self.service_handles.read().await;
handles.keys().cloned().collect()
}
}
impl Drop for VRChatOSC {
    /// Best-effort synchronous cleanup.
    ///
    /// Aborts every registered service's OSC receive task and shuts down its
    /// OSCQuery server. mDNS records are not withdrawn here; `shutdown()` does
    /// that. If the service map is locked elsewhere, nothing is cleaned up, and
    /// a warning is logged unless the thread is already panicking.
    fn drop(&mut self) {
        match self.service_handles.try_write() {
            Ok(mut handles) => {
                for (_, mut service_handle) in handles.drain() {
                    service_handle.osc.abort();
                    service_handle.osc_query.shutdown();
                }
            }
            Err(_) if std::thread::panicking() => {}
            Err(_) => {
                log::warn!("VRChatOSC: Could not acquire lock on service_handles during drop. Explicitly call shutdown() for robust cleanup.");
            }
        }
    }
}