use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use crate::client::{
EndpointRef, SippChatRequest, SippClient, SippEmbedRequest, SippEmbeddingResponse,
SippEmbeddingRun, SippQueryRequest, SippTextResponse, SippTextRun,
};
use crate::core::{FinishReason, TokenBatch, TokenUsage};
use futures_util::future::{select, Either};
use futures_util::{stream, Stream, StreamExt};
use crate::gateway_core::{GatewayError, GatewayRequestContext, GatewayResult};
/// The kind of SIPP call a gateway request performs.
///
/// Every pipeline stage (resolution, authorization, admission) receives it, so
/// policies can treat different operations on the same target differently.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Operation {
    /// Text query; used by `GatewayPipeline::query` and `GatewayPipeline::stream_query`.
    Query,
    /// Chat call; used by `GatewayPipeline::chat` and `GatewayPipeline::stream_chat`.
    Chat,
    /// Embedding call; used by `GatewayPipeline::embed`.
    Embed,
}
/// Maps a caller-facing target name to the endpoint that should serve the request.
///
/// This is the first stage of `GatewayPipeline`. An error here aborts the request
/// before authorization or admission run.
pub trait TargetResolver: Send + Sync {
    /// Resolves `target` for `operation` on behalf of the request described by `context`.
    fn resolve(
        &self,
        context: &GatewayRequestContext,
        target: &str,
        operation: Operation,
    ) -> GatewayResult<EndpointRef>;
}
/// Decides whether a request may call an already-resolved endpoint.
///
/// `GatewayPipeline` runs it after target resolution and before admission.
/// Returning `Err` rejects the request.
pub trait Authorizer: Send + Sync {
    /// Returns `Ok(())` when `operation` on `target` (resolved to `endpoint`) is allowed.
    fn authorize(
        &self,
        context: &GatewayRequestContext,
        target: &str,
        endpoint: &EndpointRef,
        operation: Operation,
    ) -> GatewayResult<()>;
}
/// Opaque token handed out by an `AdmissionController`.
///
/// The pipeline keeps the permit alive while a run is in flight and drops it
/// afterwards, so a controller can tie slot release to its guard's `Drop`.
/// Any `Send` value qualifies via the blanket impl below (e.g. `()` for no-op admission).
pub trait AdmissionPermit: Send {}

impl<T> AdmissionPermit for T where T: Send {}
/// Gate consulted after authorization that admits or rejects a request.
///
/// The returned permit is held by the pipeline while the run is in flight.
/// NOTE(review): presumably intended for concurrency/rate limiting; confirm with implementors.
pub trait AdmissionController: Send + Sync {
    /// Acquires a permit for `operation` on `target` (resolved to `endpoint`), or fails.
    fn acquire(
        &self,
        context: &GatewayRequestContext,
        target: &str,
        endpoint: &EndpointRef,
        operation: Operation,
    ) -> GatewayResult<Box<dyn AdmissionPermit>>;
}
/// Backend that performs SIPP calls on behalf of `GatewayPipeline`.
///
/// The pipeline fills in `request.endpoint` before calling any of these methods.
/// It then awaits (or streams) the returned run handle.
pub trait GatewayExecutor: Send + Sync {
    /// Returns a run handle for a text query.
    fn query(&self, context: &GatewayRequestContext, request: SippQueryRequest) -> SippTextRun;

    /// Returns a run handle for a chat call.
    fn chat(&self, context: &GatewayRequestContext, request: SippChatRequest) -> SippTextRun;

    /// Returns a run handle for an embedding call.
    fn embed(
        &self,
        context: &GatewayRequestContext,
        request: SippEmbedRequest,
    ) -> SippEmbeddingRun;
}
/// `GatewayExecutor` backed by a `SippClient` shared through an `Arc`.
///
/// Cloning the executor is cheap: clones point at the same client.
#[derive(Clone)]
pub struct SippClientExecutor {
    client: Arc<SippClient>,
}

impl SippClientExecutor {
    /// Takes ownership of `client` and places it behind a fresh `Arc`.
    pub fn new(client: SippClient) -> Self {
        Self::from_shared(Arc::new(client))
    }

    /// Reuses a client that is already shared elsewhere.
    pub fn from_shared(client: Arc<SippClient>) -> Self {
        Self { client }
    }
}
/// Forwards each gateway operation to the matching `*_with_context` client call,
/// passing along the request context's client context.
impl GatewayExecutor for SippClientExecutor {
    fn query(&self, context: &GatewayRequestContext, request: SippQueryRequest) -> SippTextRun {
        let client_context = context.client_context();
        self.client.query_with_context(client_context, request)
    }

    fn chat(&self, context: &GatewayRequestContext, request: SippChatRequest) -> SippTextRun {
        let client_context = context.client_context();
        self.client.chat_with_context(client_context, request)
    }

    fn embed(
        &self,
        context: &GatewayRequestContext,
        request: SippEmbedRequest,
    ) -> SippEmbeddingRun {
        let client_context = context.client_context();
        self.client.embed_with_context(client_context, request)
    }
}
/// Authorizer that approves every request, whatever the caller, target, endpoint or operation.
#[derive(Debug, Clone, Copy, Default)]
pub struct AllowAllAuthorizer;

impl Authorizer for AllowAllAuthorizer {
    fn authorize(
        &self,
        _: &GatewayRequestContext,
        _: &str,
        _: &EndpointRef,
        _: Operation,
    ) -> GatewayResult<()> {
        // No policy: unconditionally allowed.
        Ok(())
    }
}
/// Admission controller that never rejects; every request gets a no-op permit.
#[derive(Debug, Clone, Copy, Default)]
pub struct UnlimitedAdmissionController;

impl AdmissionController for UnlimitedAdmissionController {
    fn acquire(
        &self,
        _: &GatewayRequestContext,
        _: &str,
        _: &EndpointRef,
        _: Operation,
    ) -> GatewayResult<Box<dyn AdmissionPermit>> {
        // `()` is `Send`, so it is an `AdmissionPermit` via the blanket impl; dropping it is free.
        let permit: Box<dyn AdmissionPermit> = Box::new(());
        Ok(permit)
    }
}
#[derive(Clone)]
pub struct GatewayPipeline {
resolver: Arc<dyn TargetResolver>,
authorizer: Arc<dyn Authorizer>,
admission: Arc<dyn AdmissionController>,
executor: Arc<dyn GatewayExecutor>,
}
impl GatewayPipeline {
pub fn new(
resolver: Arc<dyn TargetResolver>,
authorizer: Arc<dyn Authorizer>,
admission: Arc<dyn AdmissionController>,
executor: Arc<dyn GatewayExecutor>,
) -> Self {
Self {
resolver,
authorizer,
admission,
executor,
}
}
pub async fn query(
&self,
context: &GatewayRequestContext,
target: &str,
mut request: SippQueryRequest,
) -> GatewayResult<SippTextResponse> {
let (endpoint, permit) = self.prepare(context, target, Operation::Query)?;
request.endpoint = Some(endpoint);
let run = self.executor.query(context, request);
context.cancellation.register(run.cancellation_handle());
let result = run.await.map_err(GatewayError::from);
drop(permit);
result
}
pub fn stream_query(
&self,
context: &GatewayRequestContext,
target: &str,
mut request: SippQueryRequest,
) -> GatewayResult<GatewayStream> {
let (endpoint, permit) = self.prepare(context, target, Operation::Query)?;
request.endpoint = Some(endpoint);
request.emit_tokens = true;
Ok(text_stream(
context,
self.executor.query(context, request),
permit,
))
}
pub async fn chat(
&self,
context: &GatewayRequestContext,
target: &str,
mut request: SippChatRequest,
) -> GatewayResult<SippTextResponse> {
let (endpoint, permit) = self.prepare(context, target, Operation::Chat)?;
request.endpoint = Some(endpoint);
let run = self.executor.chat(context, request);
context.cancellation.register(run.cancellation_handle());
let result = run.await.map_err(GatewayError::from);
drop(permit);
result
}
pub fn stream_chat(
&self,
context: &GatewayRequestContext,
target: &str,
mut request: SippChatRequest,
) -> GatewayResult<GatewayStream> {
let (endpoint, permit) = self.prepare(context, target, Operation::Chat)?;
request.endpoint = Some(endpoint);
request.emit_tokens = true;
Ok(text_stream(
context,
self.executor.chat(context, request),
permit,
))
}
pub async fn embed(
&self,
context: &GatewayRequestContext,
target: &str,
mut request: SippEmbedRequest,
) -> GatewayResult<SippEmbeddingResponse> {
let (endpoint, permit) = self.prepare(context, target, Operation::Embed)?;
request.endpoint = Some(endpoint);
let run = self.executor.embed(context, request);
context.cancellation.register(run.cancellation_handle());
let result = run.await.map_err(GatewayError::from);
drop(permit);
result
}
fn prepare(
&self,
context: &GatewayRequestContext,
target: &str,
operation: Operation,
) -> GatewayResult<(EndpointRef, Box<dyn AdmissionPermit>)> {
let endpoint = self.resolver.resolve(context, target, operation)?;
self.authorizer
.authorize(context, target, &endpoint, operation)?;
let permit = self
.admission
.acquire(context, target, &endpoint, operation)?;
Ok((endpoint, permit))
}
}
/// Pinned, boxed, `Send` stream of streaming-gateway events (or errors).
pub type GatewayStream =
    Pin<Box<dyn Stream<Item = GatewayResult<GatewayStreamEvent>> + Send + 'static>>;
/// One item emitted on a `GatewayStream`.
#[derive(Clone, Debug, PartialEq)]
pub enum GatewayStreamEvent {
    /// A batch of generated tokens, forwarded as the executor produces it.
    TokenBatch(TokenBatch),
    /// Token usage from the completed response; only emitted when the response reports it.
    Usage(TokenUsage),
    /// Last event of a successful run.
    Finished {
        /// Reason reported by the response for why generation stopped.
        finish_reason: FinishReason,
        /// Response metadata returned by the SIPP client.
        metadata: crate::client::SippResponseMetadata,
    },
}
/// State threaded through the `stream::unfold` that drives a streamed text run.
///
/// NOTE: Rust drops fields in declaration order, so reordering these fields changes
/// when the permit is released relative to the upstream handles.
struct TextStreamState {
    // Token batches produced by the run.
    tokens: crate::client::SippTokenBatches,
    // Final response future. It is taken out for each poll and put back only while tokens
    // are still flowing; it stays `None` once the response has been consumed.
    response: Option<crate::client::SippTextResponseFuture>,
    // Events queued for emission; drained before the stream checks `terminal`.
    pending: VecDeque<GatewayResult<GatewayStreamEvent>>,
    // Set once the run's final response has been observed; no upstream polling after that.
    terminal: bool,
    // Admission permit held while the run is live; released by `finish_stream` or on drop.
    permit: Option<Box<dyn AdmissionPermit>>,
}
fn text_stream(
context: &GatewayRequestContext,
run: SippTextRun,
permit: Box<dyn AdmissionPermit>,
) -> GatewayStream {
let (tokens, response, cancellation) = run.into_parts_with_cancel();
context.cancellation.register(cancellation);
let state = TextStreamState {
tokens,
response: Some(response),
pending: VecDeque::new(),
terminal: false,
permit: Some(permit),
};
Box::pin(stream::unfold(state, |mut state| async move {
if let Some(event) = state.pending.pop_front() {
return Some((event, state));
}
if state.terminal {
return None;
}
let response = state.response.take()?;
match select(state.tokens.next(), response).await {
Either::Left((Some(batch), response)) => {
state.response = Some(response);
Some((Ok(GatewayStreamEvent::TokenBatch(batch)), state))
}
Either::Left((None, response)) => {
finish_stream(&mut state, response.await);
state.pending.pop_front().map(|event| (event, state))
}
Either::Right((response, tokens)) => {
drop(tokens);
finish_stream(&mut state, response);
state.pending.pop_front().map(|event| (event, state))
}
}
}))
}
/// Records the terminal outcome of a streamed text run on `state`.
///
/// Marks the stream terminal and releases the admission permit right away. It then queues
/// either the optional `Usage` event followed by `Finished`, or a single error.
fn finish_stream(
    state: &mut TextStreamState,
    response: crate::client::SippResult<SippTextResponse>,
) {
    state.terminal = true;
    // Free the admission slot now instead of waiting for the consumer to drop the stream.
    state.permit = None;
    let response = match response {
        Ok(response) => response,
        Err(error) => {
            state.pending.push_back(Err(error.into()));
            return;
        }
    };
    state
        .pending
        .extend(response.usage.map(|usage| Ok(GatewayStreamEvent::Usage(usage))));
    state.pending.push_back(Ok(GatewayStreamEvent::Finished {
        finish_reason: response.finish_reason,
        metadata: response.metadata,
    }));
}
// Unit tests live outside `src/`; `#[path]` points the `pipeline_tests` module at that file.
#[cfg(test)]
#[path = "../tests/gateway_core/pipeline_tests.rs"]
mod pipeline_tests;