Skip to main content

cognate_core/
middleware.rs

1//! Tower-inspired middleware system for Cognate providers.
2//!
3//! Wrap any [`Provider`] with additional behaviour — logging, metrics,
4//! retries, or timeouts — without modifying the provider itself.
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! use cognate_core::{Provider, ProviderExt, middleware::{TracingMiddleware, MiddlewareLayer}};
10//! use std::sync::Arc;
11//!
12//! fn wrap<P: Provider + 'static>(provider: P) -> Arc<dyn Provider> {
13//!     provider.with(MiddlewareLayer::new(TracingMiddleware::default()))
14//! }
15//! ```
16
17use async_trait::async_trait;
18use crate::{Provider, Request, Response, Result, Chunk};
19use futures::stream::BoxStream;
20use std::sync::Arc;
21
22// ─── Layer / Middleware traits ─────────────────────────────────────────────
23
24/// A factory that wraps a [`Provider`] to produce a new [`Provider`].
25pub trait Layer: Send + Sync {
26    /// Wrap `provider` with this layer and return the decorated provider.
27    fn layer(&self, provider: Arc<dyn Provider>) -> Arc<dyn Provider>;
28}
29
30/// A middleware that intercepts completion and streaming requests.
31///
32/// Default implementations delegate directly to `next`, so you only need
33/// to override the methods you care about.
34#[async_trait]
35pub trait Middleware: Send + Sync {
36    /// Intercept a non-streaming completion request.
37    async fn complete(&self, req: Request, next: &dyn Provider) -> Result<Response> {
38        next.complete(req).await
39    }
40
41    /// Intercept a streaming request.
42    async fn stream(
43        &self,
44        req: Request,
45        next: &dyn Provider,
46    ) -> Result<BoxStream<'static, Result<Chunk>>> {
47        next.stream(req).await
48    }
49}
50
51// ─── MiddlewareProvider ────────────────────────────────────────────────────
52
53/// A [`Provider`] that applies a [`Middleware`] before delegating to an inner
54/// provider.
55pub struct MiddlewareProvider<M> {
56    middleware: M,
57    inner: Arc<dyn Provider>,
58}
59
60#[async_trait]
61impl<M: Middleware> Provider for MiddlewareProvider<M> {
62    async fn complete(&self, req: Request) -> Result<Response> {
63        self.middleware.complete(req, self.inner.as_ref()).await
64    }
65
66    async fn stream(&self, req: Request) -> Result<BoxStream<'static, Result<Chunk>>> {
67        self.middleware.stream(req, self.inner.as_ref()).await
68    }
69}
70
71// ─── MiddlewareLayer ───────────────────────────────────────────────────────
72
/// A [`Layer`] that applies a concrete [`Middleware`] implementation.
///
/// The layer is `Clone` (and `Debug`) whenever the middleware is. A configured
/// layer can therefore be passed by value to several `.with(...)` calls.
#[derive(Clone, Debug)]
pub struct MiddlewareLayer<M> {
    middleware: M,
}
77
78impl<M> MiddlewareLayer<M> {
79    /// Create a new layer wrapping `middleware`.
80    pub fn new(middleware: M) -> Self {
81        Self { middleware }
82    }
83}
84
85impl<M: Middleware + Clone + 'static> Layer for MiddlewareLayer<M> {
86    fn layer(&self, provider: Arc<dyn Provider>) -> Arc<dyn Provider> {
87        Arc::new(MiddlewareProvider {
88            middleware: self.middleware.clone(),
89            inner: provider,
90        })
91    }
92}
93
94// ─── ProviderExt ───────────────────────────────────────────────────────────
95
96/// Extension trait that adds a fluent `.with(layer)` combinator to any [`Provider`].
97pub trait ProviderExt: Provider {
98    /// Wrap `self` with the given [`Layer`], returning a new [`Arc<dyn Provider>`].
99    fn with<L: Layer>(self, layer: L) -> Arc<dyn Provider>
100    where
101        Self: Sized + 'static,
102    {
103        layer.layer(Arc::new(self))
104    }
105}
106
107/// Blanket implementation so every [`Provider`] automatically gets `.with()`.
108impl<P: Provider + Sized + 'static> ProviderExt for P {}
109
110// ─── Built-in middleware ───────────────────────────────────────────────────
111
/// Middleware that emits [`tracing`] events for completion and streaming
/// requests.
///
/// Add to any provider with:
/// ```rust,no_run
/// use cognate_core::{ProviderExt, middleware::{TracingMiddleware, MiddlewareLayer}};
/// # use cognate_core::MockProvider;
/// let provider = MockProvider::new()
///     .with(MiddlewareLayer::new(TracingMiddleware::default()));
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct TracingMiddleware;
123
124#[async_trait]
125impl Middleware for TracingMiddleware {
126    async fn complete(&self, req: Request, next: &dyn Provider) -> Result<Response> {
127        let model = req.model.clone();
128        tracing::info!(model = %model, "cognate: sending completion request");
129        let response = next.complete(req).await?;
130        if let Some(usage) = response.usage() {
131            tracing::info!(
132                model = %model,
133                prompt_tokens = usage.prompt_tokens,
134                completion_tokens = usage.completion_tokens,
135                "cognate: request completed"
136            );
137        }
138        Ok(response)
139    }
140
141    async fn stream(
142        &self,
143        req: Request,
144        next: &dyn Provider,
145    ) -> Result<BoxStream<'static, Result<Chunk>>> {
146        tracing::info!(model = %req.model, "cognate: starting streaming request");
147        next.stream(req).await
148    }
149}