use async_trait::async_trait;
use crate::{Provider, Request, Response, Result, Chunk};
use futures::stream::BoxStream;
use std::sync::Arc;
/// Transforms one [`Provider`] into another, usually by wrapping it.
///
/// Implementors receive the provider to decorate and return the provider that
/// callers should use from then on.
pub trait Layer: Send + Sync {
    /// Wraps `inner` and returns the resulting provider.
    fn layer(&self, inner: Arc<dyn Provider>) -> Arc<dyn Provider>;
}
#[async_trait]
pub trait Middleware: Send + Sync {
async fn complete(&self, req: Request, next: &dyn Provider) -> Result<Response> {
next.complete(req).await
}
async fn stream(
&self,
req: Request,
next: &dyn Provider,
) -> Result<BoxStream<'static, Result<Chunk>>> {
next.stream(req).await
}
}
/// A [`Provider`] that sends every call through a [`Middleware`], handing it
/// the wrapped `inner` provider as `next`.
///
/// [`MiddlewareLayer`] produces values of this type. The fields are private.
pub struct MiddlewareProvider<M> {
    /// Interceptor invoked for every call.
    middleware: M,
    /// Provider passed to the middleware as `next`.
    inner: Arc<dyn Provider>,
}

#[async_trait]
impl<M: Middleware> Provider for MiddlewareProvider<M> {
    async fn complete(&self, req: Request) -> Result<Response> {
        let next: &dyn Provider = &*self.inner;
        self.middleware.complete(req, next).await
    }

    async fn stream(&self, req: Request) -> Result<BoxStream<'static, Result<Chunk>>> {
        let next: &dyn Provider = &*self.inner;
        self.middleware.stream(req, next).await
    }
}
/// A [`Layer`] that wraps each provider in a [`MiddlewareProvider`] holding a
/// clone of one template [`Middleware`].
///
/// `Clone` and `Debug` are derived so one layer can be reused or inspected.
/// Each impl applies only when `M` itself implements that trait.
#[derive(Clone, Debug)]
pub struct MiddlewareLayer<M> {
    /// Template middleware. It is cloned once for every provider this layer wraps.
    middleware: M,
}

impl<M> MiddlewareLayer<M> {
    /// Creates a layer that applies `middleware` to every provider it wraps.
    #[must_use]
    pub fn new(middleware: M) -> Self {
        Self { middleware }
    }
}

impl<M: Middleware + Clone + 'static> Layer for MiddlewareLayer<M> {
    /// Wraps `provider` in a [`MiddlewareProvider`] that owns a fresh clone of
    /// this layer's middleware.
    fn layer(&self, provider: Arc<dyn Provider>) -> Arc<dyn Provider> {
        Arc::new(MiddlewareProvider {
            middleware: self.middleware.clone(),
            inner: provider,
        })
    }
}
/// Convenience combinators for any sized, `'static` [`Provider`].
pub trait ProviderExt: Provider {
    /// Consumes this provider, puts it behind an `Arc`, and passes it to `layer`.
    ///
    /// Returns the provider produced by `layer`.
    fn with<L>(self, layer: L) -> Arc<dyn Provider>
    where
        L: Layer,
        Self: Sized + 'static,
    {
        let shared: Arc<dyn Provider> = Arc::new(self);
        layer.layer(shared)
    }
}

/// Every sized, `'static` provider automatically gets [`ProviderExt`].
impl<P> ProviderExt for P where P: Provider + 'static {}
/// Middleware that emits `tracing` info events around provider calls.
///
/// For completions it logs the model before the call and logs completion
/// afterwards, including token usage when the response reports it. For
/// streaming requests it logs only the start of the request.
#[derive(Clone, Debug, Default)]
pub struct TracingMiddleware;

#[async_trait]
impl Middleware for TracingMiddleware {
    /// Logs the outgoing completion request, forwards it to `next`, and logs
    /// that the request completed.
    ///
    /// # Errors
    /// Propagates any error from `next.complete` unchanged. No completion
    /// event is logged in that case.
    async fn complete(&self, req: Request, next: &dyn Provider) -> Result<Response> {
        // `req` is moved into `next.complete`, so keep an owned copy of the
        // model for the post-call event.
        let model = req.model.clone();
        tracing::info!(model = %model, "cognate: sending completion request");
        let response = next.complete(req).await?;
        if let Some(usage) = response.usage() {
            tracing::info!(
                model = %model,
                prompt_tokens = usage.prompt_tokens,
                completion_tokens = usage.completion_tokens,
                "cognate: request completed"
            );
        } else {
            // Providers may omit usage. Still record that the request finished,
            // so successful calls are not silently missing from the trace.
            tracing::info!(model = %model, "cognate: request completed");
        }
        Ok(response)
    }

    /// Logs the start of a streaming request and forwards it to `next`.
    /// Individual chunks are not traced.
    async fn stream(
        &self,
        req: Request,
        next: &dyn Provider,
    ) -> Result<BoxStream<'static, Result<Chunk>>> {
        tracing::info!(model = %req.model, "cognate: starting streaming request");
        next.stream(req).await
    }
}