use std::collections::{HashMap, VecDeque};
use base64::engine::Config;
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use tokio::{
sync::{Mutex, oneshot},
time::{self, Duration},
};
use tokio_websockets::Message;
pub use tokio_websockets::proto::CloseCode;
use tracing::debug;
use super::InnerError;
use crate::{
requests::{ClientRequest, EventSubscription, Identify},
responses::{Hello, Identified, RequestResponse, ServerMessage, Status},
};
/// Registry of in-flight requests waiting for their response from obs-websocket.
///
/// Each entry maps a numeric request ID to the sending half of a one-shot channel. The
/// matching receiver is handed out by [`ReceiverList::add`] and completed by
/// [`ReceiverList::notify`].
#[derive(Default)]
pub(super) struct ReceiverList(Mutex<HashMap<u64, oneshot::Sender<(Status, serde_json::Value)>>>);

impl ReceiverList {
    /// Registers `id` and returns the receiver through which its response will arrive.
    ///
    /// A sender previously registered under the same `id` is replaced and dropped.
    pub async fn add(&self, id: u64) -> oneshot::Receiver<(Status, serde_json::Value)> {
        let (sender, receiver) = oneshot::channel();
        let mut pending = self.0.lock().await;
        pending.insert(id, sender);
        receiver
    }

    /// Forgets the pending entry for `id`, if there is one.
    pub async fn remove(&self, id: u64) {
        let mut pending = self.0.lock().await;
        pending.remove(&id);
    }

    /// Routes `response` to the receiver registered under its request ID.
    ///
    /// Responses whose ID has no registered receiver are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`InnerError::InvalidRequestId`] when the response's ID is not a valid `u64`.
    pub async fn notify(&self, response: RequestResponse) -> Result<(), InnerError> {
        let RequestResponse {
            r#type: _,
            id: raw_id,
            status,
            data,
        } = response;

        let request_id: u64 = match raw_id.parse() {
            Ok(parsed) => parsed,
            Err(err) => return Err(InnerError::InvalidRequestId(err, raw_id)),
        };

        let mut pending = self.0.lock().await;
        if let Some(sender) = pending.remove(&request_id) {
            // `send` only fails when the receiving side is gone; nothing left to deliver to.
            let _ = sender.send((status, data));
        }
        Ok(())
    }

    /// Drops every pending entry.
    pub async fn reset(&self) {
        let mut pending = self.0.lock().await;
        pending.clear();
    }
}
/// First-in, first-out queue of waiters for an `Identified` message.
///
/// NOTE(review): presumably used for re-identify requests, judging by the name — confirm
/// against callers.
#[derive(Default)]
pub(super) struct ReidentifyReceiverList(Mutex<VecDeque<oneshot::Sender<Identified>>>);

impl ReidentifyReceiverList {
    /// Appends a new waiter to the back of the queue and returns its receiver.
    pub async fn add(&self) -> oneshot::Receiver<Identified> {
        let (sender, receiver) = oneshot::channel();
        let mut queue = self.0.lock().await;
        queue.push_back(sender);
        receiver
    }

    /// Hands `identified` to the oldest queued waiter; does nothing if the queue is empty.
    pub async fn notify(&self, identified: Identified) {
        let mut queue = self.0.lock().await;
        if let Some(sender) = queue.pop_front() {
            // A failed send just means that waiter's receiver was already dropped.
            let _ = sender.send(identified);
        }
    }

    /// Discards every queued waiter.
    pub async fn reset(&self) {
        let mut queue = self.0.lock().await;
        queue.clear();
    }
}
/// Errors that can occur during the initial handshake with obs-websocket.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum HandshakeError {
    /// The connection was closed before the handshake completed.
    ///
    /// Holds `None` when the message stream simply ended, or `Some` with the code and reason
    /// of the close frame sent by the server.
    #[error("connection to obs-websocket was closed: {}", match .0 {
        Some(details) => &details.reason,
        None => "no details provided",
    })]
    ConnectionClosed(Option<CloseDetails>),
    /// Reading a message from the websocket failed.
    #[error("failed reading websocket message")]
    Receive(#[from] ReceiveError),
    /// A received message was not a text message, so it could not be decoded as JSON.
    #[error("websocket message not convertible to text")]
    IntoText,
    /// A text message could not be deserialized into a server message.
    #[error("failed deserializing message")]
    DeserializeMessage(#[from] crate::error::DeserializeResponseError),
    /// A client message (such as the `Identify` request) could not be serialized.
    #[error("failed serializing message")]
    SerializeMessage(#[from] crate::error::SerializeMessageError),
    /// Writing a message to the websocket failed.
    #[error("failed to send message to obs-websocket")]
    Send(#[from] crate::error::SendError),
    /// The first message from the server was not `Hello`, or it did not arrive in time.
    #[error("didn't receive a `Hello` message after connecting")]
    NoHello,
    /// The server did not answer the `Identify` request with an `Identified` message.
    #[error("didn't receive a `Identified` message")]
    NoIdentified,
}
/// Failure while reading a message from the underlying websocket stream.
///
/// Transparent wrapper around the [`tokio_websockets::Error`] yielded by the stream. Its
/// `Display` output and `source` are forwarded unchanged.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ReceiveError(tokio_websockets::Error);
/// Code and reason taken from a close frame received from obs-websocket.
#[derive(Debug)]
pub struct CloseDetails {
    /// Close code sent by the server.
    pub code: CloseCode,
    /// Close reason text sent by the server.
    pub reason: String,
}
pub(super) async fn handshake(
write: &mut (impl Sink<Message, Error = tokio_websockets::Error> + Unpin),
read: &mut (impl Stream<Item = Result<Message, tokio_websockets::Error>> + Unpin),
password: Option<&str>,
event_subscriptions: Option<EventSubscription>,
) -> Result<(), HandshakeError> {
async fn read_message(
read: &mut (impl Stream<Item = Result<Message, tokio_websockets::Error>> + Unpin),
) -> Result<ServerMessage, HandshakeError> {
let message = read
.next()
.await
.ok_or(HandshakeError::ConnectionClosed(None))?
.map_err(ReceiveError)?;
if let Some((code, reason)) = message.as_close() {
return Err(HandshakeError::ConnectionClosed(Some(CloseDetails {
code,
reason: reason.to_owned(),
})));
}
let message = message.as_text().ok_or(HandshakeError::IntoText)?;
serde_json::from_str::<ServerMessage>(message)
.map_err(crate::error::DeserializeResponseError)
.map_err(Into::into)
}
let server_message = time::timeout(Duration::from_secs(5), read_message(read))
.await
.map_err(|_| HandshakeError::NoHello)?;
match server_message? {
ServerMessage::Hello(Hello {
obs_web_socket_version: _,
rpc_version,
authentication,
}) => {
let authentication = authentication.zip(password).map(|(auth, password)| {
create_auth_response(&auth.challenge, &auth.salt, password)
});
let req = serde_json::to_string(&ClientRequest::Identify(Identify {
rpc_version,
authentication,
event_subscriptions,
}))
.map_err(crate::error::SerializeMessageError)?;
write
.send(Message::text(req))
.await
.map_err(crate::error::SendError)?;
}
_ => return Err(HandshakeError::NoHello),
}
match read_message(read).await? {
ServerMessage::Identified(Identified {
negotiated_rpc_version,
}) => {
debug!(rpc_version = %negotiated_rpc_version, "identified against obs-websocket");
}
_ => return Err(HandshakeError::NoIdentified),
}
Ok(())
}
/// Computes the `authentication` value sent in an `Identify` request.
///
/// The result is `base64(sha256(base64(sha256(password + salt)) + challenge))`. Both encodings
/// use the standard, padded base64 alphabet.
fn create_auth_response(challenge: &str, salt: &str, password: &str) -> String {
    use base64::engine::{Engine, general_purpose};
    use sha2::{Digest, Sha256};

    // SHA-256 over two byte strings fed one after the other.
    let sha256_of = |first: &[u8], second: &[u8]| {
        let mut hasher = Sha256::new();
        hasher.update(first);
        hasher.update(second);
        hasher.finalize()
    };

    // Every string built here is the base64 form of a SHA-256 digest, so one buffer of this
    // size fits both the intermediate secret and the final response. If the length
    // computation reports `None`, fall back to an unsized buffer.
    let digest_b64_len = base64::encoded_len(
        Sha256::output_size(),
        general_purpose::STANDARD.config().encode_padding(),
    )
    .unwrap_or_default();
    let mut buffer = String::with_capacity(digest_b64_len);

    // Step 1: secret = base64(sha256(password + salt)).
    general_purpose::STANDARD
        .encode_string(sha256_of(password.as_bytes(), salt.as_bytes()), &mut buffer);

    // Step 2: response = base64(sha256(secret + challenge)), reusing the same buffer.
    let response_digest = sha256_of(buffer.as_bytes(), challenge.as_bytes());
    buffer.clear();
    general_purpose::STANDARD.encode_string(response_digest, &mut buffer);

    buffer
}