use std::sync::Arc;
use anyhow::Context;
use tokio::sync::oneshot;
use url::Url;
use crate::{Error, Id, NonZeroSlab, State, ffi};
/// Book-keeping for one connection task spawned by `Session::connect`.
struct TaskEntry {
    /// Status callback that receives the task's final result when it ends.
    callback: ffi::OnStatus,
    /// Close trigger for the running task. Dropping this sender resolves the
    /// task's close branch. `Session::close` takes it, leaving `None`.
    close: Option<oneshot::Sender<()>>,
}
/// Registry of connection tasks started by `Session::connect`.
///
/// Each slot is keyed by the `Id` returned to the caller. The spawned task
/// removes its own slot when it finishes. Slots hold `Option<TaskEntry>`, and
/// a `None` slot is treated the same as a missing one.
#[derive(Default)]
pub struct Session {
    // Live tasks, indexed by the `Id` handed out from `connect`.
    task: NonZeroSlab<Option<TaskEntry>>,
}
impl Session {
pub fn connect(
&mut self,
url: Url,
publish: Option<moq_net::OriginConsumer>,
consume: Option<moq_net::OriginProducer>,
callback: ffi::OnStatus,
) -> Result<Id, Error> {
let closed = oneshot::channel();
let entry = TaskEntry {
close: Some(closed.0),
callback,
};
let id = self.task.insert(Some(entry))?;
tokio::spawn(async move {
let res = tokio::select! {
_ = closed.1 => Ok(()),
res = Self::connect_run(callback, url, publish, consume) => res,
};
let entry = State::lock().session.task.remove(id).flatten();
if let Some(entry) = entry {
entry.callback.call(res);
}
});
Ok(id)
}
async fn connect_run(
callback: ffi::OnStatus,
url: Url,
publish: Option<moq_net::OriginConsumer>,
consume: Option<moq_net::OriginProducer>,
) -> Result<(), Error> {
let reconnect = moq_native::ClientConfig::default()
.init()?
.with_publish(publish)
.with_consume(consume)
.reconnect(url);
Self::report(callback, reconnect).await
}
async fn report(callback: ffi::OnStatus, mut reconnect: moq_native::Reconnect) -> Result<(), Error> {
let mut connects: u64 = 0;
loop {
if let moq_native::Status::Connected = reconnect.status().await.map_err(map_connect_error)? {
connects += 1;
let code = i32::try_from(connects)
.context("connection epoch exceeded i32::MAX")
.map_err(|err| Error::Connect(Arc::new(err)))?;
callback.call(code);
}
}
}
pub fn close(&mut self, id: Id) -> Result<(), Error> {
self.task
.get_mut(id)
.and_then(|entry| entry.as_mut())
.ok_or(Error::SessionNotFound)?
.close
.take()
.ok_or(Error::SessionNotFound)?;
Ok(())
}
}
fn map_connect_error(err: moq_native::Error) -> Error {
match err.connect_error() {
Some(moq_native::ConnectError::Unauthorized) => Error::Unauthorized,
Some(moq_native::ConnectError::Forbidden) => Error::Forbidden,
_ => Error::Connect(Arc::new(err.into())),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::ReturnCode;

    /// Auth rejections must map to their dedicated `Error` variants. Any other
    /// connect failure must map to `Error::Connect`.
    #[test]
    fn maps_native_auth_connect_errors() {
        let native_unauthorized =
            map_connect_error(moq_native::ConnectError::Unauthorized.into());
        assert!(matches!(native_unauthorized, Error::Unauthorized));

        let native_forbidden = map_connect_error(moq_native::ConnectError::Forbidden.into());
        assert!(matches!(native_forbidden, Error::Forbidden));

        // An unauthorized error originating in moq_net is recognised as well.
        let net_unauthorized = map_connect_error(moq_net::Error::Unauthorized.into());
        assert!(matches!(net_unauthorized, Error::Unauthorized));

        let generic_failure = map_connect_error(moq_native::Error::ConnectFailed);
        assert!(matches!(generic_failure, Error::Connect(_)));

        // Pin the numeric codes exposed through `ffi::ReturnCode`.
        assert_eq!(Error::Unauthorized.code(), -33);
        assert_eq!(Error::Forbidden.code(), -34);
        assert_eq!(map_connect_error(moq_native::Error::ConnectFailed).code(), -5);
    }
}