use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use crate::core::TokenBatch;
use crate::engine::EngineTokenBatches;
use futures::future::{select, Either};
use futures_channel::mpsc;
use futures_channel::oneshot;
use futures_core::Stream;
use crate::client::{SippEmbeddingResponse, SippError, SippResult, SippTextResponse};
// Tests for this module live in a separate file (see the `#[path]` below) and
// are only compiled for test builds.
#[cfg(test)]
#[path = "../tests/client/run_tests.rs"]
mod run_tests;
/// Boxed, pinned, `Send` future that resolves to a `SippResult<SippTextResponse>`.
pub type SippTextResponseFuture =
Pin<Box<dyn Future<Output = SippResult<SippTextResponse>> + Send>>;
/// Boxed, pinned, `Send` future that resolves to a `SippResult<SippEmbeddingResponse>`.
pub type SippEmbeddingResponseFuture =
Pin<Box<dyn Future<Output = SippResult<SippEmbeddingResponse>> + Send>>;
/// Reason attached to a cancellation request.
///
/// A value of this type is handed to `SippCancellationHandle::cancel` and is
/// reported back to whoever awaits the run inside
/// `SippError::Cancelled { reason }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SippCancellationReason {
/// Rendered as `"caller_cancelled"` by `as_str`.
CallerCancelled,
/// Rendered as `"client_disconnected"` by `as_str`.
ClientDisconnected,
/// Rendered as `"server_shutdown"` by `as_str`.
ServerShutdown,
/// Rendered as `"deadline_exceeded"` by `as_str`.
DeadlineExceeded,
}
impl SippCancellationReason {
pub const fn as_str(self) -> &'static str {
match self {
Self::CallerCancelled => "caller_cancelled",
Self::ClientDisconnected => "client_disconnected",
Self::ServerShutdown => "server_shutdown",
Self::DeadlineExceeded => "deadline_exceeded",
}
}
}
/// Cloneable handle used to request cancellation of a run.
///
/// Every clone shares the same one-shot sender, so across all clones at most
/// one `cancel` call actually delivers a reason.
#[derive(Clone)]
pub struct SippCancellationHandle {
// Shared slot holding the one-shot sender; `cancel` takes it out and leaves `None`.
sender: Arc<Mutex<Option<oneshot::Sender<SippCancellationReason>>>>,
}
impl SippCancellationHandle {
    /// Requests cancellation, tagging the request with `reason`.
    ///
    /// Only the first call across all clones of this handle delivers a
    /// reason; once the sender has been taken, further calls do nothing.
    /// A failed send (the receiving side is already gone) is deliberately
    /// ignored, since there is nothing left to cancel in that case.
    pub fn cancel(&self, reason: SippCancellationReason) {
        // The mutex only guards an `Option`, which cannot be observed in a
        // half-updated state, so a poisoned lock is recovered instead of
        // silently discarding the cancellation request.
        //
        // The guard is a temporary of this statement: the lock is released
        // before `send` runs, so waking the receiver never happens while the
        // mutex is held.
        let sender = self
            .sender
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();

        if let Some(sender) = sender {
            let _ = sender.send(reason);
        }
    }
}
/// An in-flight text request: its final response future, its token-batch
/// stream, and a handle that can cancel the response.
pub struct SippTextRun {
// Final response; `new` stores the cancellable wrapper returned by `cancellable_response`.
response: SippTextResponseFuture,
// Token batches associated with this run.
tokens: SippTokenBatches,
// Handle paired with `response` by `cancellable_response`.
cancellation: SippCancellationHandle,
}
impl SippTextRun {
pub(crate) fn new(response: SippTextResponseFuture, tokens: SippTokenBatches) -> Self {
let (response, cancellation) = cancellable_response(response);
Self {
response,
tokens,
cancellation,
}
}
pub fn from_response(response: SippTextResponseFuture) -> Self {
Self::new(response, SippTokenBatches::closed())
}
pub fn from_parts(tokens: SippTokenBatches, response: SippTextResponseFuture) -> Self {
Self::new(response, tokens)
}
pub(crate) fn ready_err(error: SippError) -> Self {
Self::new(
Box::pin(async move { Err(error) }),
SippTokenBatches::closed(),
)
}
pub fn tokens(&mut self) -> &mut SippTokenBatches {
&mut self.tokens
}
pub fn cancellation_handle(&self) -> SippCancellationHandle {
self.cancellation.clone()
}
pub fn cancel(&self, reason: SippCancellationReason) {
self.cancellation.cancel(reason);
}
pub fn into_parts(self) -> (SippTokenBatches, SippTextResponseFuture) {
(self.tokens, self.response)
}
pub fn into_parts_with_cancel(
self,
) -> (
SippTokenBatches,
SippTextResponseFuture,
SippCancellationHandle,
) {
(self.tokens, self.response, self.cancellation)
}
}
impl Future for SippTextRun {
    type Output = SippResult<SippTextResponse>;

    /// Polls the stored response future; the token stream is not touched here.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let run = self.get_mut();
        run.response.as_mut().poll(cx)
    }
}
/// An in-flight embedding request: its response future plus a handle that can
/// cancel it.
pub struct SippEmbeddingRun {
// Response; `new` stores the cancellable wrapper returned by `cancellable_response`.
response: SippEmbeddingResponseFuture,
// Handle paired with `response` by `cancellable_response`.
cancellation: SippCancellationHandle,
}
impl SippEmbeddingRun {
pub(crate) fn new(response: SippEmbeddingResponseFuture) -> Self {
let (response, cancellation) = cancellable_response(response);
Self {
response,
cancellation,
}
}
pub fn from_response(response: SippEmbeddingResponseFuture) -> Self {
Self::new(response)
}
pub(crate) fn ready_err(error: SippError) -> Self {
Self::new(Box::pin(async move { Err(error) }))
}
pub fn cancellation_handle(&self) -> SippCancellationHandle {
self.cancellation.clone()
}
pub fn cancel(&self, reason: SippCancellationReason) {
self.cancellation.cancel(reason);
}
pub fn into_response(self) -> SippEmbeddingResponseFuture {
self.response
}
pub fn into_parts(self) -> (SippEmbeddingResponseFuture, SippCancellationHandle) {
(self.response, self.cancellation)
}
}
impl Future for SippEmbeddingRun {
    type Output = SippResult<SippEmbeddingResponse>;

    /// Polls the stored response future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let run = self.get_mut();
        run.response.as_mut().poll(cx)
    }
}
/// Stream of `TokenBatch` items, backed by one of several possible sources.
pub struct SippTokenBatches {
// Which concrete source the stream forwards to.
inner: TokenBatchSource,
}
// Backing source for `SippTokenBatches`.
enum TokenBatchSource {
// No source: polling reports end-of-stream immediately.
Empty,
// Token batches taken directly from an `EngineTokenBatches` value.
Local(EngineTokenBatches),
// Token batches received over an unbounded mpsc channel.
Receiver(mpsc::UnboundedReceiver<TokenBatch>),
// Token batches from a caller-supplied boxed stream.
External(Pin<Box<dyn Stream<Item = TokenBatch> + Send>>),
}
impl SippTokenBatches {
    /// A stream with no source; it reports end-of-stream on the first poll.
    pub(crate) fn closed() -> Self {
        let inner = TokenBatchSource::Empty;
        Self { inner }
    }

    /// Wraps an optional engine stream; `None` produces a closed stream.
    pub(crate) fn from_engine(stream: Option<EngineTokenBatches>) -> Self {
        let inner = stream.map_or(TokenBatchSource::Empty, TokenBatchSource::Local);
        Self { inner }
    }

    /// Wraps the receiving end of an unbounded mpsc channel of token batches.
    pub(crate) fn from_receiver(receiver: mpsc::UnboundedReceiver<TokenBatch>) -> Self {
        let inner = TokenBatchSource::Receiver(receiver);
        Self { inner }
    }

    /// Wraps a caller-supplied boxed stream of token batches.
    pub fn from_stream(stream: Pin<Box<dyn Stream<Item = TokenBatch> + Send>>) -> Self {
        let inner = TokenBatchSource::External(stream);
        Self { inner }
    }
}
impl Stream for SippTokenBatches {
    type Item = TokenBatch;

    /// Forwards the poll to whichever source backs this stream.
    ///
    /// A stream without a source reports end-of-stream right away.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let source = &mut self.get_mut().inner;
        match source {
            TokenBatchSource::External(boxed) => boxed.as_mut().poll_next(cx),
            TokenBatchSource::Receiver(channel) => Pin::new(channel).poll_next(cx),
            TokenBatchSource::Local(engine) => Pin::new(engine).poll_next(cx),
            TokenBatchSource::Empty => Poll::Ready(None),
        }
    }
}
/// Wraps `response` so it can be cancelled, and returns the wrapped future
/// together with the handle that triggers the cancellation.
///
/// The wrapped future races `response` against a one-shot cancellation
/// signal:
/// - if `response` wins, its result is returned unchanged;
/// - if a reason arrives, the still-pending `response` is dropped and
///   `Err(SippError::Cancelled { reason })` is returned;
/// - if the signal resolves to an error instead of a reason, `response` is
///   simply awaited to completion.
fn cancellable_response<T>(
    response: Pin<Box<dyn Future<Output = SippResult<T>> + Send>>,
) -> (
    Pin<Box<dyn Future<Output = SippResult<T>> + Send>>,
    SippCancellationHandle,
)
where
    T: Send + 'static,
{
    let (cancel_tx, cancel_rx) = oneshot::channel();

    let guarded: Pin<Box<dyn Future<Output = SippResult<T>> + Send>> = Box::pin(async move {
        // First stage: whichever of the two futures completes first.
        let (signal, pending) = match select(response, cancel_rx).await {
            Either::Left((result, _cancel_rx)) => return result,
            Either::Right(outcome) => outcome,
        };

        // Second stage: the cancellation channel resolved before the response.
        match signal {
            Ok(reason) => {
                // Drop the unfinished response before reporting the cancellation.
                drop(pending);
                Err(SippError::Cancelled { reason })
            }
            // No reason was delivered; let the response finish normally.
            Err(_) => pending.await,
        }
    });

    let handle = SippCancellationHandle {
        sender: Arc::new(Mutex::new(Some(cancel_tx))),
    };

    (guarded, handle)
}