use ort::session::Session;
use std::ops::{Deref, DerefMut};
/// The three `ort` sessions that make up one model instance.
///
/// They are checked out of the pool together as a single unit (`SessionPool` is a pool of
/// these).
///
/// NOTE(review): the encoder/decoder/joiner split looks like a transducer-style model —
/// confirm before documenting it as such.
///
/// Fields are dropped in declaration order (encoder, then decoder, then joiner).
pub struct SessionTriplet {
    /// Encoder session.
    pub(crate) encoder: Session,
    /// Decoder session.
    pub(crate) decoder: Session,
    /// Joiner session.
    pub(crate) joiner: Session,
}
/// Error returned when an item cannot be checked out of a pool.
///
/// It is a fieldless enum, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolError {
    /// The pool's channel has been closed and no idle items remain in it.
    Closed,
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PoolError::Closed => write!(f, "session pool is closed"),
        }
    }
}

impl std::error::Error for PoolError {}
/// A fixed-size pool of reusable items backed by a bounded `async_channel`.
///
/// Idle items wait in the channel. Checking one out receives it from the channel, and the
/// returned [`PoolGuard`] sends it back when dropped.
pub struct Pool<T> {
    sender: async_channel::Sender<T>,
    receiver: async_channel::Receiver<T>,
    // Number of items the pool was built with; never updated afterwards.
    total: usize,
}

/// A [`Pool`] of [`SessionTriplet`]s.
pub type SessionPool = Pool<SessionTriplet>;

impl<T> Pool<T> {
    /// Creates a pool in which every element of `items` starts out idle.
    ///
    /// The channel is sized to the item count, so all items fit at once. An empty `items`
    /// still gets a one-slot channel, because `async_channel::bounded` panics on a capacity
    /// of 0. `checkout` on such an empty pool waits until the pool is closed.
    ///
    /// # Panics
    ///
    /// Only if an initial `try_send` fails, which the capacity choice rules out.
    pub fn new(items: Vec<T>) -> Self {
        let total = items.len();
        let capacity = std::cmp::max(total, 1);
        let (sender, receiver) = async_channel::bounded(capacity);
        items.into_iter().for_each(|item| {
            sender
                .try_send(item)
                .expect("channel capacity matches item count")
        });
        Self {
            sender,
            receiver,
            total,
        }
    }

    /// Waits asynchronously for an idle item and checks it out.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError::Closed`] once the pool has been closed and no idle items are
    /// left. Items that were already idle when [`Pool::close`] ran can still be checked out.
    pub async fn checkout(&self) -> Result<PoolGuard<'_, T>, PoolError> {
        let item = self.receiver.recv().await.map_err(|_| PoolError::Closed)?;
        Ok(self.wrap_checked_out(item))
    }

    /// Blocking counterpart of [`Pool::checkout`]: blocks the current thread until an item
    /// is idle.
    ///
    /// Avoid calling this from inside an async task, because it blocks the thread the task
    /// runs on.
    ///
    /// # Errors
    ///
    /// Same as [`Pool::checkout`].
    pub fn checkout_blocking(&self) -> Result<PoolGuard<'_, T>, PoolError> {
        self.receiver
            .recv_blocking()
            .map(|item| self.wrap_checked_out(item))
            .map_err(|_| PoolError::Closed)
    }

    /// Wraps an item just received from the channel in a guard that returns it on drop.
    fn wrap_checked_out(&self, item: T) -> PoolGuard<'_, T> {
        PoolGuard {
            pool: self,
            item: Some(item),
        }
    }

    /// Closes the pool's channel.
    ///
    /// After this call, guards that are dropped fail to send their items back, so those
    /// items are dropped. Items still idle in the channel can be received by `checkout`
    /// until the channel is empty. After that, `checkout` returns [`PoolError::Closed`].
    pub fn close(&self) {
        // Both handles refer to the same channel, so the second call finds it already closed.
        self.sender.close();
        self.receiver.close();
    }

    /// Number of items the pool was created with.
    ///
    /// This is fixed at construction. It does not decrease when an item is permanently
    /// removed, for example via `PoolItemGuard::into_inner` or a dropped
    /// `OwnedReservation`.
    pub fn total(&self) -> usize {
        self.total
    }

    /// Number of items currently idle in the channel.
    ///
    /// This is a snapshot and may already be stale when the caller acts on it.
    pub fn available(&self) -> usize {
        self.receiver.len()
    }
}
/// RAII handle to an item checked out of a [`Pool`].
///
/// Dereferences to the item. When the guard is dropped, the item is sent back into the
/// pool with a best-effort `try_send`. If that fails (e.g. the pool was closed), the item
/// is dropped instead.
#[must_use = "dropping a PoolGuard immediately returns the item to the pool"]
pub struct PoolGuard<'a, T> {
    pool: &'a Pool<T>,
    // `Some` for the guard's whole life; only `into_owned` or `drop` take it out.
    item: Option<T>,
}

impl<T> PoolGuard<'_, T> {
    /// Moves the item out of the guard, ending the borrow of the pool.
    ///
    /// Returns the item plus an [`OwnedReservation`] that holds a clone of the pool's
    /// sender. The reservation is used to hand the item (or a replacement) back later.
    /// Because the guard's slot is emptied, dropping `self` here returns nothing to the
    /// pool.
    ///
    /// If the result is discarded, or the reservation is dropped without checking an item
    /// back in, the pool permanently has one fewer item. `checkout` may then wait
    /// indefinitely once the remaining items are all in use.
    #[must_use = "discarding the result drops the item and permanently shrinks the pool"]
    pub fn into_owned(mut self) -> (T, OwnedReservation<T>) {
        let item = self
            .item
            .take()
            .expect("PoolGuard::into_owned called after drop");
        let reservation = OwnedReservation {
            sender: self.pool.sender.clone(),
        };
        (item, reservation)
    }
}

impl<T> Deref for PoolGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.item
            .as_ref()
            .expect("PoolGuard accessed after item taken")
    }
}

impl<T> DerefMut for PoolGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.item
            .as_mut()
            .expect("PoolGuard accessed after item taken")
    }
}

impl<T> Drop for PoolGuard<'_, T> {
    fn drop(&mut self) {
        // `None` means `into_owned` already moved the item out; nothing to return.
        if let Some(item) = self.item.take() {
            // Best effort: if the pool was closed the send fails and the item is dropped.
            let _ = self.pool.sender.try_send(item);
        }
    }
}
/// A detached claim used to return an item to a [`Pool`] after [`PoolGuard::into_owned`].
///
/// It holds a clone of the pool's sender, so it does not borrow the pool. Return an item
/// with [`checkin`](Self::checkin), or wrap it in a drop-guard with [`guard`](Self::guard).
/// Dropping the reservation without doing either returns nothing. The pool is then left
/// one item short, while `Pool::total` still reports the original count.
#[must_use = "dropping an OwnedReservation without checking an item back in shrinks the pool"]
pub struct OwnedReservation<T> {
    sender: async_channel::Sender<T>,
}

impl<T> OwnedReservation<T> {
    /// Sends `item` back to the pool, consuming the reservation.
    ///
    /// Best effort: if the send fails (e.g. the pool was closed), the error is ignored and
    /// the item is dropped.
    pub fn checkin(self, item: T) {
        let _ = self.sender.try_send(item);
    }

    /// Wraps `item` in a [`PoolItemGuard`] that sends it back to the pool when dropped.
    pub fn guard(self, item: T) -> PoolItemGuard<T> {
        PoolItemGuard {
            reservation: self,
            item: Some(item),
        }
    }
}
/// Owning guard around a pool item that sends the item back to its pool when dropped.
///
/// Built by `OwnedReservation::guard`. Unlike `PoolGuard`, it does not borrow the pool,
/// because the reservation carries its own clone of the pool's sender.
///
/// On drop, the item is returned with a best-effort `try_send`. If that fails (e.g. after
/// the pool was closed), the item is dropped instead. [`PoolItemGuard::into_inner`] skips
/// the return entirely, so the pool permanently loses that item.
pub struct PoolItemGuard<T> {
    reservation: OwnedReservation<T>,
    // Always `Some` while the guard is reachable; only `into_inner` or `drop` empty it.
    item: Option<T>,
}

impl<T> PoolItemGuard<T> {
    /// Returns a mutable reference to the guarded item.
    ///
    /// # Panics
    ///
    /// If the item has already been taken (not reachable through this type's public API).
    pub fn item_mut(&mut self) -> &mut T {
        match &mut self.item {
            Some(inner) => inner,
            None => panic!("PoolItemGuard item already taken"),
        }
    }

    /// Returns a shared reference to the guarded item.
    ///
    /// # Panics
    ///
    /// If the item has already been taken (not reachable through this type's public API).
    pub fn item(&self) -> &T {
        match &self.item {
            Some(inner) => inner,
            None => panic!("PoolItemGuard item already taken"),
        }
    }

    /// Consumes the guard and returns the item without sending it back to the pool.
    ///
    /// The pool permanently has one fewer item afterwards.
    pub fn into_inner(mut self) -> T {
        let taken = self.item.take();
        taken.expect("PoolItemGuard item already taken")
    }
}

impl<T> Deref for PoolItemGuard<T> {
    type Target = T;

    fn deref(&self) -> &T {
        PoolItemGuard::item(self)
    }
}

impl<T> DerefMut for PoolItemGuard<T> {
    fn deref_mut(&mut self) -> &mut T {
        PoolItemGuard::item_mut(self)
    }
}

impl<T> Drop for PoolItemGuard<T> {
    fn drop(&mut self) {
        let sender = &self.reservation.sender;
        // `None` here means `into_inner` already handed the item to the caller.
        if let Some(returned) = self.item.take() {
            // A failed send (closed pool) just drops the item.
            let _ = sender.try_send(returned);
        }
    }
}