#[cfg(feature = "events")]
use std::sync::Weak;
use std::{
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use futures_util::{
sink::SinkExt,
stream::{SplitSink, Stream, StreamExt},
};
use semver::{Comparator, Op, Prerelease};
use serde::de::DeserializeOwned;
#[cfg(feature = "events")]
use tokio::sync::broadcast;
use tokio::{net::TcpStream, sync::Mutex, task::JoinHandle};
use tokio_websockets::{MaybeTlsStream, Message, WebSocketStream};
use tracing::{debug, error, info, trace, warn};
use self::connection::{ReceiverList, ReidentifyReceiverList};
pub use self::{
canvases::Canvases,
config::Config,
connection::{HandshakeError, ReceiveError},
filters::Filters,
general::General,
hotkeys::Hotkeys,
inputs::Inputs,
media_inputs::MediaInputs,
outputs::Outputs,
profiles::Profiles,
recording::Recording,
replay_buffer::ReplayBuffer,
scene_collections::SceneCollections,
scene_items::SceneItems,
scenes::Scenes,
sources::Sources,
streaming::Streaming,
transitions::Transitions,
ui::Ui,
virtual_cam::VirtualCam,
};
#[cfg(feature = "events")]
use crate::events::Event;
use crate::{
error::{Error, Result},
requests::{ClientRequest, EventSubscription, Reidentify, Request, RequestType},
responses::ServerMessage,
};
mod canvases;
mod config;
mod connection;
mod filters;
mod general;
mod hotkeys;
mod inputs;
mod media_inputs;
mod outputs;
mod profiles;
mod recording;
mod replay_buffer;
mod scene_collections;
mod scene_items;
mod scenes;
mod sources;
mod streaming;
mod transitions;
mod ui;
mod virtual_cam;
/// Errors that can occur while the background receive loop handles a single message.
///
/// Within this file they are only logged by `recv_loop`.
#[derive(Debug, thiserror::Error)]
enum InnerError {
    /// The websocket message has no text representation (`Message::as_text` returned `None`).
    #[error("websocket message not convertible to text")]
    IntoText,
    /// The text payload could not be deserialized into a [`ServerMessage`].
    #[error("failed deserializing message")]
    DeserializeMessage(#[source] serde_json::Error),
    /// The request ID of a response is not an integer.
    ///
    /// Field `0` is the parse failure (exposed as the error source), field `1` is the offending
    /// ID. The message must therefore interpolate `{1}`; `{0}` would print the parse error in
    /// place of the ID.
    #[error("the request ID `{1}` is not an integer")]
    InvalidRequestId(#[source] std::num::ParseIntError, String),
    /// A server message for which `recv_loop` has no handler.
    #[error("received unexpected server message: {0:?}")]
    UnexpectedMessage(ServerMessage),
}
/// Client for a single obs-websocket connection.
///
/// The accessor methods (`general()`, `scenes()`, ...) hand out lightweight handles that borrow
/// the client.
pub struct Client {
/// Write half of the websocket. It is locked for every outgoing message, which serializes
/// concurrent senders.
write: Mutex<MessageWriter>,
/// Source of request IDs; starts at `1` and is incremented for every request sent.
id_counter: AtomicU64,
/// Pending request receivers, keyed by request ID and shared with the receive loop.
receivers: Arc<ReceiverList>,
/// Pending receivers for the answer to a re-identify message, shared with the receive loop.
reidentify_receivers: Arc<ReidentifyReceiverList>,
/// Weak handle to the event broadcast sender. A strong reference is handed to the receive loop
/// task, so upgrading fails once no strong reference is left.
#[cfg(feature = "events")]
event_sender: Weak<broadcast::Sender<Event>>,
/// Handle of the spawned receive loop task; taken (set to `None`) by `disconnect`.
handle: Option<JoinHandle<()>>,
/// Flags that allow skipping individual version checks in `verify_versions`.
dangerous: DangerousConnectConfig,
}
/// Write half of the websocket stream, as obtained from `socket.split()` when connecting.
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Default capacity of the event broadcast channel (used by `Client::connect` and as the builder
/// default for `ConnectConfig::broadcast_capacity`).
pub const DEFAULT_BROADCAST_CAPACITY: usize = 100;
/// Default timeout for establishing the websocket connection: 30 seconds (used by
/// `Client::connect` and as the builder default for `ConnectConfig::connect_timeout`).
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Configuration for `Client::connect_with_config`.
///
/// With the `builder` feature enabled, a builder is derived for this struct.
#[cfg_attr(feature = "builder", derive(bon::Builder))]
pub struct ConnectConfig<H, P>
where
H: AsRef<str>,
P: AsRef<str>,
{
/// Host name or address; inserted into the connection URI as `<scheme>://<host>:<port>`.
#[cfg_attr(feature = "builder", builder(start_fn))]
pub host: H,
/// Port; inserted into the connection URI together with `host`.
#[cfg_attr(feature = "builder", builder(start_fn))]
pub port: u16,
/// Optional dangerous settings. When set, a warning is logged on connect and the contained
/// flags can disable individual version checks. `None` behaves like the default settings.
#[cfg_attr(feature = "builder", builder(field))]
pub dangerous: Option<DangerousConnectConfig>,
/// Optional password, forwarded to the handshake.
pub password: Option<P>,
/// Optional event subscription selection, forwarded to the handshake.
pub event_subscriptions: Option<EventSubscription>,
/// Whether to connect with the `wss` scheme instead of `ws`. Only present with the `tls`
/// feature.
#[cfg(feature = "tls")]
#[cfg_attr(feature = "builder", builder(default))]
pub tls: bool,
/// Capacity of the broadcast channel that distributes events. Only read when the `events`
/// feature is enabled.
#[cfg_attr(not(feature = "events"), allow(dead_code))]
#[cfg_attr(feature = "builder", builder(default = DEFAULT_BROADCAST_CAPACITY))]
pub broadcast_capacity: usize,
/// Maximum time to wait for the websocket connection to be established.
#[cfg_attr(feature = "builder", builder(default = DEFAULT_CONNECT_TIMEOUT))]
pub connect_timeout: Duration,
}
#[cfg(feature = "builder")]
impl<H, P, S> ConnectConfigBuilder<H, P, S>
where
    H: AsRef<str>,
    P: AsRef<str>,
    S: connect_config_builder::State,
{
    /// Configure the **dangerous** connection options through a nested builder.
    ///
    /// The closure receives a fresh [`DangerousConnectConfigBuilder`]; the builder it returns is
    /// finished and stored, replacing whatever an earlier call may have set.
    pub fn dangerous<S2: dangerous_connect_config_builder::State>(
        mut self,
        f: impl FnOnce(DangerousConnectConfigBuilder) -> DangerousConnectConfigBuilder<S2>,
    ) -> Self {
        let options = f(DangerousConnectConfig::builder()).build();
        self.dangerous = Some(options);
        self
    }
}
/// Settings that disable safety checks performed right after connecting.
///
/// Using them makes the client log a warning on connect. Both flags default to `false`.
#[derive(Default)]
#[cfg_attr(feature = "builder", derive(bon::Builder))]
pub struct DangerousConnectConfig {
/// Skip comparing the reported OBS Studio version against the supported range.
#[cfg_attr(feature = "builder", builder(default))]
pub skip_studio_version_check: bool,
/// Skip comparing the reported obs-websocket version against the supported range.
#[cfg_attr(feature = "builder", builder(default))]
pub skip_websocket_version_check: bool,
}
/// OBS Studio version requirement checked after connecting: `>=30.2`.
const OBS_STUDIO_VERSION: Comparator = Comparator {
op: Op::GreaterEq,
major: 30,
minor: Some(2),
patch: None,
pre: Prerelease::EMPTY,
};
/// obs-websocket version requirement checked after connecting: `^5.5`.
const OBS_WEBSOCKET_VERSION: Comparator = Comparator {
op: Op::Caret,
major: 5,
minor: Some(5),
patch: None,
pre: Prerelease::EMPTY,
};
/// RPC version this client expects; the version reported by the server must be equal to it.
const RPC_VERSION: u32 = 1;
impl<H, P> ConnectConfig<H, P>
where
H: AsRef<str>,
P: AsRef<str>,
{
/// Whether a TLS (`wss`) connection was requested; returns the `tls` field.
#[cfg(feature = "tls")]
fn tls(&self) -> bool {
self.tls
}
/// Without the `tls` feature there is no `tls` field, so this is always `false`. Having the
/// method in both configurations lets callers stay free of feature gates.
#[cfg(not(feature = "tls"))]
#[expect(clippy::unused_self)]
fn tls(&self) -> bool {
false
}
}
impl Client {
/// Connect to an obs-websocket instance at `host:port` with default settings.
///
/// This is a shorthand for [`Self::connect_with_config`] using the default broadcast capacity
/// and connect timeout, no TLS and no dangerous options. When the `events` feature is disabled,
/// `EventSubscription::NONE` is requested; otherwise no explicit subscription is passed.
///
/// # Errors
///
/// Returns whatever [`Self::connect_with_config`] returns for the assembled configuration.
pub async fn connect(
    host: impl AsRef<str>,
    port: u16,
    password: Option<impl AsRef<str>>,
) -> Result<Self> {
    let event_subscriptions = if cfg!(feature = "events") {
        None
    } else {
        Some(EventSubscription::NONE)
    };

    let config = ConnectConfig {
        host,
        port,
        dangerous: None,
        password,
        event_subscriptions,
        #[cfg(feature = "tls")]
        tls: false,
        broadcast_capacity: DEFAULT_BROADCAST_CAPACITY,
        connect_timeout: DEFAULT_CONNECT_TIMEOUT,
    };

    Self::connect_with_config(config).await
}
/// Connect to an obs-websocket instance using the given [`ConnectConfig`].
///
/// The steps are, in order: open the websocket (bounded by `config.connect_timeout`), perform
/// the handshake, spawn the background receive loop and finally verify the versions reported
/// by the server.
///
/// # Errors
///
/// Fails if the URI built from host and port is rejected, the connection attempt fails or
/// times out, the handshake fails, or the version verification rejects the server.
pub async fn connect_with_config<H, P>(config: ConnectConfig<H, P>) -> Result<Self>
where
    H: AsRef<str>,
    P: AsRef<str>,
{
    if config.dangerous.is_some() {
        warn!(
            "dangerous configuration is being used. Please note that no support is given for \
             any issues encountered while using these options"
        );
    }

    // Only the connection attempt itself is subject to the timeout. The outer `Result` is the
    // timeout outcome, the inner one the connection outcome.
    let (socket, _) = tokio::time::timeout(
        config.connect_timeout,
        tokio_websockets::ClientBuilder::new()
            .uri(&format!(
                "{}://{}:{}",
                if config.tls() { "wss" } else { "ws" },
                config.host.as_ref(),
                config.port
            ))
            .map_err(crate::error::InvalidUriError)?
            .connect(),
    )
    .await
    .map_err(|_| Error::Timeout)?
    .map_err(crate::error::ConnectError)?;

    let (mut write, mut read) = socket.split();

    let receivers = Arc::new(ReceiverList::default());
    let reidentify_receivers = Arc::new(ReidentifyReceiverList::default());

    #[cfg(feature = "events")]
    let (event_sender, _) = broadcast::channel(config.broadcast_capacity);
    #[cfg(feature = "events")]
    let event_sender = Arc::new(event_sender);
    // The receive loop gets a strong reference; the client only keeps a weak one (see below).
    #[cfg(feature = "events")]
    let events_tx = Arc::clone(&event_sender);

    // The handshake runs before the receive loop is spawned, so it has exclusive access to
    // both halves of the socket.
    self::connection::handshake(
        &mut write,
        &mut read,
        config.password.as_ref().map(AsRef::as_ref),
        config.event_subscriptions,
    )
    .await?;

    let handle = tokio::spawn(recv_loop(
        read,
        #[cfg(feature = "events")]
        events_tx,
        Arc::clone(&receivers),
        Arc::clone(&reidentify_receivers),
    ));

    let write = Mutex::new(write);
    let id_counter = AtomicU64::new(1);

    let client = Self {
        write,
        id_counter,
        receivers,
        reidentify_receivers,
        #[cfg(feature = "events")]
        event_sender: Arc::downgrade(&event_sender),
        handle: Some(handle),
        dangerous: config.dangerous.unwrap_or_default(),
    };

    // On failure the client is dropped here, which aborts the receive loop again.
    client.verify_versions().await?;

    Ok(client)
}
async fn verify_versions(&self) -> Result<()> {
let version = self.general().version().await?;
if !self.dangerous.skip_studio_version_check
&& !OBS_STUDIO_VERSION.matches(&version.obs_studio_version)
{
return Err(Error::ObsStudioVersion(
version.obs_studio_version,
OBS_STUDIO_VERSION,
));
}
if !self.dangerous.skip_websocket_version_check
&& !OBS_WEBSOCKET_VERSION.matches(&version.obs_web_socket_version)
{
return Err(Error::ObsWebsocketVersion(
version.obs_web_socket_version,
OBS_WEBSOCKET_VERSION,
));
}
if RPC_VERSION != version.rpc_version {
return Err(Error::RpcVersion {
requested: RPC_VERSION,
negotiated: version.rpc_version,
});
}
Ok(())
}
async fn send_message<'a, R, T>(&self, req: R) -> Result<T>
where
R: Into<RequestType<'a>>,
T: DeserializeOwned,
{
async fn send(
id_counter: &AtomicU64,
receivers: &Arc<ReceiverList>,
write: &Mutex<MessageWriter>,
req: RequestType<'_>,
) -> Result<serde_json::Value> {
let id = id_counter.fetch_add(1, Ordering::SeqCst);
let id_str = id.to_string();
let req = ClientRequest::Request(Request {
request_id: &id_str,
ty: req,
});
let json = serde_json::to_string(&req).map_err(crate::error::SerializeMessageError)?;
let rx = receivers.add(id).await;
trace!(%json, "sending message");
let write_result = write
.lock()
.await
.send(Message::text(json))
.await
.map_err(crate::error::SendError);
if let Err(e) = write_result {
receivers.remove(id).await;
return Err(e.into());
}
let (status, resp) = rx.await.map_err(crate::error::ReceiveMessageError)?;
if !status.result {
return Err(Error::Api {
code: status.code,
message: status.comment,
});
}
Ok(resp)
}
let resp = send(&self.id_counter, &self.receivers, &self.write, req.into()).await?;
serde_json::from_value(resp)
.map_err(crate::error::DeserializeResponseError)
.map_err(Into::into)
}
/// Disconnect by aborting the background receive loop.
///
/// The task is aborted right away, before this function returns. Awaiting the returned future
/// additionally waits until the task has actually finished. Calling this again after the
/// handle has been taken yields a future that completes immediately.
pub fn disconnect(&mut self) -> impl Future + use<> {
    let handle = self.handle.take();
    if let Some(task) = &handle {
        task.abort();
    }

    async move {
        if let Some(task) = handle {
            // The join result of an aborted task is of no interest here.
            task.await.ok();
        }
    }
}
pub async fn reidentify(&self, event_subscriptions: EventSubscription) -> Result<()> {
let json = serde_json::to_string(&ClientRequest::Reidentify(Reidentify {
event_subscriptions: Some(event_subscriptions),
}))
.map_err(crate::error::SerializeMessageError)?;
let rx = self.reidentify_receivers.add().await;
self.write
.lock()
.await
.send(Message::text(json))
.await
.map_err(crate::error::SendError)?;
let resp = rx.await.map_err(crate::error::ReceiveMessageError)?;
debug!(
rpc_version = %resp.negotiated_rpc_version,
"re-identified against obs-websocket",
);
Ok(())
}
/// Subscribe to the events broadcast by the receive loop.
///
/// Each call creates a new, independent subscription.
///
/// # Errors
///
/// Returns [`crate::error::Error::Disconnected`] if the broadcast sender is gone, i.e. the weak
/// reference held by the client can no longer be upgraded.
#[cfg(feature = "events")]
pub fn events(&self) -> Result<crate::events::EventStream> {
    let Some(sender) = self.event_sender.upgrade() else {
        return Err(crate::error::Error::Disconnected);
    };

    Ok(crate::events::EventStream::new(sender.subscribe()))
}
/// Returns a [`Canvases`] handle that borrows this client.
pub fn canvases(&self) -> Canvases<'_> {
Canvases { client: self }
}
/// Returns a [`Config`] handle that borrows this client.
pub fn config(&self) -> Config<'_> {
Config { client: self }
}
/// Returns a [`Filters`] handle that borrows this client.
pub fn filters(&self) -> Filters<'_> {
Filters { client: self }
}
/// Returns a [`General`] handle that borrows this client.
pub fn general(&self) -> General<'_> {
General { client: self }
}
/// Returns a [`Hotkeys`] handle that borrows this client.
pub fn hotkeys(&self) -> Hotkeys<'_> {
Hotkeys { client: self }
}
/// Returns an [`Inputs`] handle that borrows this client.
pub fn inputs(&self) -> Inputs<'_> {
Inputs { client: self }
}
/// Returns a [`MediaInputs`] handle that borrows this client.
pub fn media_inputs(&self) -> MediaInputs<'_> {
MediaInputs { client: self }
}
/// Returns an [`Outputs`] handle that borrows this client.
pub fn outputs(&self) -> Outputs<'_> {
Outputs { client: self }
}
/// Returns a [`Profiles`] handle that borrows this client.
pub fn profiles(&self) -> Profiles<'_> {
Profiles { client: self }
}
/// Returns a [`Recording`] handle that borrows this client.
pub fn recording(&self) -> Recording<'_> {
Recording { client: self }
}
/// Returns a [`ReplayBuffer`] handle that borrows this client.
pub fn replay_buffer(&self) -> ReplayBuffer<'_> {
ReplayBuffer { client: self }
}
/// Returns a [`SceneCollections`] handle that borrows this client.
pub fn scene_collections(&self) -> SceneCollections<'_> {
SceneCollections { client: self }
}
/// Returns a [`SceneItems`] handle that borrows this client.
pub fn scene_items(&self) -> SceneItems<'_> {
SceneItems { client: self }
}
/// Returns a [`Scenes`] handle that borrows this client.
pub fn scenes(&self) -> Scenes<'_> {
Scenes { client: self }
}
/// Returns a [`Sources`] handle that borrows this client.
pub fn sources(&self) -> Sources<'_> {
Sources { client: self }
}
/// Returns a [`Streaming`] handle that borrows this client.
pub fn streaming(&self) -> Streaming<'_> {
Streaming { client: self }
}
/// Returns a [`Transitions`] handle that borrows this client.
pub fn transitions(&self) -> Transitions<'_> {
Transitions { client: self }
}
/// Returns a [`Ui`] handle that borrows this client.
pub fn ui(&self) -> Ui<'_> {
Ui { client: self }
}
/// Returns a [`VirtualCam`] handle that borrows this client.
pub fn virtual_cam(&self) -> VirtualCam<'_> {
VirtualCam { client: self }
}
}
impl Drop for Client {
    /// Abort the background receive loop when the client goes away.
    ///
    /// This performs the synchronous part of `disconnect` (take the handle and abort the
    /// task) without waiting for the task to finish, as `drop` cannot await.
    fn drop(&mut self) {
        if let Some(task) = self.handle.take() {
            task.abort();
        }
    }
}
async fn recv_loop(
mut read: impl Stream<Item = Result<Message, tokio_websockets::Error>> + Unpin,
#[cfg(feature = "events")] events_tx: Arc<broadcast::Sender<Event>>,
receivers: Arc<ReceiverList>,
reidentify_receivers: Arc<ReidentifyReceiverList>,
) {
while let Some(Ok(msg)) = read.next().await {
if let Some((_, reason)) = msg.as_close() {
if !reason.is_empty() {
info!(%reason, "connection closed with reason");
}
#[cfg(feature = "events")]
events_tx.send(Event::ServerStopping).ok();
continue;
}
let res: Result<(), InnerError> = async {
let text = msg.as_text().ok_or(InnerError::IntoText)?;
let message = serde_json::from_str::<ServerMessage>(text)
.map_err(InnerError::DeserializeMessage)?;
match message {
ServerMessage::RequestResponse(response) => {
trace!(
id = %response.id,
status = ?response.status,
data = %response.data,
"got request-response message",
);
receivers.notify(response).await?;
}
#[cfg(feature = "events")]
ServerMessage::Event(event) => {
trace!(?event, "got OBS event");
events_tx.send(event).ok();
}
#[cfg(not(feature = "events"))]
ServerMessage::Event => {
trace!("got OBS event");
}
ServerMessage::Identified(identified) => {
trace!(?identified, "got identified message");
reidentify_receivers.notify(identified).await;
}
_ => {
trace!(?message, "got unexpected message");
return Err(InnerError::UnexpectedMessage(message));
}
}
Ok(())
}
.await;
if let Err(error) = res {
error!(?error, "failed handling message");
}
}
#[cfg(feature = "events")]
events_tx.send(Event::ServerStopped).ok();
receivers.reset().await;
reidentify_receivers.reset().await;
}