cognate_core/middleware.rs
//! Tower-inspired middleware system for Cognate providers.
//!
//! Wrap any [`Provider`] with additional behaviour — such as logging, metrics,
//! retries, or timeouts — without modifying the provider itself.
//!
//! # Example
//!
//! ```rust,no_run
//! use cognate_core::{Provider, ProviderExt, middleware::{TracingMiddleware, MiddlewareLayer}};
//! use std::sync::Arc;
//!
//! fn wrap<P: Provider + 'static>(provider: P) -> Arc<dyn Provider> {
//!     provider.with(MiddlewareLayer::new(TracingMiddleware::default()))
//! }
//! ```
16
17use async_trait::async_trait;
18use crate::{Provider, Request, Response, Result, Chunk};
19use futures::stream::BoxStream;
20use std::sync::Arc;
21
22// ─── Layer / Middleware traits ─────────────────────────────────────────────
23
24/// A factory that wraps a [`Provider`] to produce a new [`Provider`].
25pub trait Layer: Send + Sync {
26 /// Wrap `provider` with this layer and return the decorated provider.
27 fn layer(&self, provider: Arc<dyn Provider>) -> Arc<dyn Provider>;
28}
29
30/// A middleware that intercepts completion and streaming requests.
31///
32/// Default implementations delegate directly to `next`, so you only need
33/// to override the methods you care about.
34#[async_trait]
35pub trait Middleware: Send + Sync {
36 /// Intercept a non-streaming completion request.
37 async fn complete(&self, req: Request, next: &dyn Provider) -> Result<Response> {
38 next.complete(req).await
39 }
40
41 /// Intercept a streaming request.
42 async fn stream(
43 &self,
44 req: Request,
45 next: &dyn Provider,
46 ) -> Result<BoxStream<'static, Result<Chunk>>> {
47 next.stream(req).await
48 }
49}
50
51// ─── MiddlewareProvider ────────────────────────────────────────────────────
52
53/// A [`Provider`] that applies a [`Middleware`] before delegating to an inner
54/// provider.
55pub struct MiddlewareProvider<M> {
56 middleware: M,
57 inner: Arc<dyn Provider>,
58}
59
60#[async_trait]
61impl<M: Middleware> Provider for MiddlewareProvider<M> {
62 async fn complete(&self, req: Request) -> Result<Response> {
63 self.middleware.complete(req, self.inner.as_ref()).await
64 }
65
66 async fn stream(&self, req: Request) -> Result<BoxStream<'static, Result<Chunk>>> {
67 self.middleware.stream(req, self.inner.as_ref()).await
68 }
69}
70
71// ─── MiddlewareLayer ───────────────────────────────────────────────────────
72
73/// A [`Layer`] that applies a concrete [`Middleware`] implementation.
74pub struct MiddlewareLayer<M> {
75 middleware: M,
76}
77
78impl<M> MiddlewareLayer<M> {
79 /// Create a new layer wrapping `middleware`.
80 pub fn new(middleware: M) -> Self {
81 Self { middleware }
82 }
83}
84
85impl<M: Middleware + Clone + 'static> Layer for MiddlewareLayer<M> {
86 fn layer(&self, provider: Arc<dyn Provider>) -> Arc<dyn Provider> {
87 Arc::new(MiddlewareProvider {
88 middleware: self.middleware.clone(),
89 inner: provider,
90 })
91 }
92}
93
94// ─── ProviderExt ───────────────────────────────────────────────────────────
95
96/// Extension trait that adds a fluent `.with(layer)` combinator to any [`Provider`].
97pub trait ProviderExt: Provider {
98 /// Wrap `self` with the given [`Layer`], returning a new [`Arc<dyn Provider>`].
99 fn with<L: Layer>(self, layer: L) -> Arc<dyn Provider>
100 where
101 Self: Sized + 'static,
102 {
103 layer.layer(Arc::new(self))
104 }
105}
106
107/// Blanket implementation so every [`Provider`] automatically gets `.with()`.
108impl<P: Provider + Sized + 'static> ProviderExt for P {}
109
110// ─── Built-in middleware ───────────────────────────────────────────────────
111
112/// Middleware that logs every request and response using [`tracing`].
113///
114/// Add to any provider with:
115/// ```rust,no_run
116/// use cognate_core::{ProviderExt, middleware::{TracingMiddleware, MiddlewareLayer}};
117/// # use cognate_core::MockProvider;
118/// let provider = MockProvider::new()
119/// .with(MiddlewareLayer::new(TracingMiddleware::default()));
120/// ```
121#[derive(Clone, Default)]
122pub struct TracingMiddleware;
123
124#[async_trait]
125impl Middleware for TracingMiddleware {
126 async fn complete(&self, req: Request, next: &dyn Provider) -> Result<Response> {
127 let model = req.model.clone();
128 tracing::info!(model = %model, "cognate: sending completion request");
129 let response = next.complete(req).await?;
130 if let Some(usage) = response.usage() {
131 tracing::info!(
132 model = %model,
133 prompt_tokens = usage.prompt_tokens,
134 completion_tokens = usage.completion_tokens,
135 "cognate: request completed"
136 );
137 }
138 Ok(response)
139 }
140
141 async fn stream(
142 &self,
143 req: Request,
144 next: &dyn Provider,
145 ) -> Result<BoxStream<'static, Result<Chunk>>> {
146 tracing::info!(model = %req.model, "cognate: starting streaming request");
147 next.stream(req).await
148 }
149}