Skip to main content

cognate_axum/
lib.rs

1//! Axum integration for Cognate.
2//!
3//! # Features
4//!
5//! * [`CognateProvider`] — an Axum extractor that pulls an `Arc<dyn Provider>`
6//!   out of shared application state via [`axum::extract::FromRef`].
//! * [`UsageLayer`] / [`UsageHandle`] — provider middleware (a
//!   `cognate_core::middleware::Layer`, not a Tower layer) that accumulates the
//!   token usage reported by non-streaming `complete` calls.
9//! * [`into_sse`] — convert a `BoxStream<Result<Chunk>>` directly into an Axum
10//!   [`Sse`] response.
11//!
12//! # Wiring up the provider
13//!
14//! ```rust,no_run
15//! use axum::{Router, routing::post};
16//! use cognate_axum::CognateProvider;
17//! use cognate_core::{Provider, Request};
18//! use std::sync::Arc;
19//!
20//! #[derive(Clone)]
21//! struct AppState {
22//!     provider: Arc<dyn Provider>,
23//! }
24//!
25//! impl axum::extract::FromRef<AppState> for Arc<dyn Provider> {
26//!     fn from_ref(state: &AppState) -> Self {
27//!         state.provider.clone()
28//!     }
29//! }
30//!
31//! async fn chat(
32//!     CognateProvider(provider): CognateProvider,
33//!     axum::Json(req): axum::Json<Request>,
34//! ) -> String {
35//!     provider.complete(req).await.unwrap().content().to_string()
36//! }
37//!
38//! #[tokio::main]
39//! async fn main() {
40//!     let provider: Arc<dyn Provider> = unimplemented!();
41//!     let _app: Router = Router::new()
42//!         .route("/chat", post(chat))
43//!         .with_state(AppState { provider });
44//! }
45//! ```
46#![warn(missing_docs)]
47
48use axum::{
49    extract::{FromRef, FromRequestParts},
50    http::request::Parts,
51    response::sse::{Event, Sse},
52};
53use cognate_core::{Chunk, Provider};
54use futures::stream::{BoxStream, Stream, StreamExt};
55use std::{
56    convert::Infallible,
57    sync::{
58        atomic::{AtomicU64, Ordering},
59        Arc,
60    },
61};
62
63// ─── CognateProvider extractor ─────────────────────────────────────────────
64
65/// Axum extractor that resolves an `Arc<dyn Provider>` from shared application
66/// state.
67///
68/// The application state `S` must implement
69/// `axum::extract::FromRef<S, Target = Arc<dyn Provider>>`.
70///
71/// See the [crate documentation](crate) for a full wiring example.
72pub struct CognateProvider(pub Arc<dyn Provider>);
73
74impl<S> FromRequestParts<S> for CognateProvider
75where
76    Arc<dyn Provider>: FromRef<S>,
77    S: Send + Sync,
78{
79    type Rejection = Infallible;
80
81    async fn from_request_parts(
82        _parts: &mut Parts,
83        state: &S,
84    ) -> Result<Self, Infallible> {
85        Ok(CognateProvider(Arc::<dyn Provider>::from_ref(state)))
86    }
87}
88
89// ─── SSE helper ────────────────────────────────────────────────────────────
90
91/// Convert a streaming provider response into an Axum [`Sse`] response.
92///
93/// Text deltas are emitted as `data:` events.  Provider errors are emitted as
94/// `event: error` events so the client can detect them.
95///
96/// # Example
97///
98/// ```rust,no_run
99/// use axum::{extract::State, response::Sse};
100/// use cognate_axum::{CognateProvider, into_sse};
101/// use cognate_core::{Request, Message};
102///
103/// async fn stream_handler(
104///     CognateProvider(provider): CognateProvider,
105///     axum::Json(req): axum::Json<Request>,
106/// ) -> impl axum::response::IntoResponse {
107///     let stream = provider.stream(req).await.unwrap();
108///     into_sse(stream)
109/// }
110/// ```
111pub fn into_sse(
112    stream: BoxStream<'static, cognate_core::Result<Chunk>>,
113) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
114    let mapped = stream.map(|result| match result {
115        Ok(chunk) => {
116            let mut event = Event::default().data(chunk.content());
117            if chunk.is_finished() {
118                event = event.event("done");
119            }
120            Ok(event)
121        }
122        Err(e) => Ok(Event::default().event("error").data(e.to_string())),
123    });
124    Sse::new(mapped)
125}
126
127// ─── UsageHandle ───────────────────────────────────────────────────────────
128
/// Shared, cloneable view of the token counters that [`UsageLayer`] adds to.
///
/// Every clone points at the same two `Arc<AtomicU64>` counters, so a clone
/// kept by the application observes updates made through any other clone.
#[derive(Debug, Clone, Default)]
pub struct UsageHandle {
    /// Running sum of prompt tokens across all recorded requests.
    pub prompt_tokens: Arc<AtomicU64>,
    /// Running sum of completion tokens across all recorded requests.
    pub completion_tokens: Arc<AtomicU64>,
}

impl UsageHandle {
    /// Returns a handle whose counters both start at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Current prompt-token total.
    pub fn prompt_tokens(&self) -> u64 {
        let counter: &AtomicU64 = &self.prompt_tokens;
        counter.load(Ordering::Relaxed)
    }

    /// Current completion-token total.
    pub fn completion_tokens(&self) -> u64 {
        let counter: &AtomicU64 = &self.completion_tokens;
        counter.load(Ordering::Relaxed)
    }

    /// Current sum of prompt and completion tokens.
    ///
    /// The two counters are read one after the other (prompt first), not as a
    /// single atomic snapshot, so a concurrent update may be only partially
    /// reflected in the result.
    pub fn total_tokens(&self) -> u64 {
        let prompt = self.prompt_tokens();
        let completion = self.completion_tokens();
        prompt + completion
    }
}
161
162// ─── UsageLayer (Tower Layer) ──────────────────────────────────────────────
163
164/// Tower [`Layer`] that wraps a [`Provider`] and records token usage.
165///
166/// # Example
167///
168/// ```rust,no_run
169/// use cognate_axum::{UsageLayer, UsageHandle};
170/// use cognate_core::{Provider, ProviderExt, middleware::MiddlewareLayer};
171/// # use cognate_core::MockProvider;
172///
173/// let handle = UsageHandle::new();
174/// let layer = UsageLayer::new(handle.clone());
175/// let provider = MockProvider::new().with(layer);
176/// // After requests, inspect: handle.total_tokens()
177/// ```
178#[derive(Clone)]
179pub struct UsageLayer {
180    handle: UsageHandle,
181}
182
183impl UsageLayer {
184    /// Create a new layer that reports into `handle`.
185    pub fn new(handle: UsageHandle) -> Self {
186        Self { handle }
187    }
188}
189
190impl cognate_core::middleware::Layer for UsageLayer {
191    fn layer(&self, provider: Arc<dyn Provider>) -> Arc<dyn Provider> {
192        Arc::new(UsageProvider {
193            inner: provider,
194            handle: self.handle.clone(),
195        })
196    }
197}
198
199// ─── UsageProvider ─────────────────────────────────────────────────────────
200
201struct UsageProvider {
202    inner: Arc<dyn Provider>,
203    handle: UsageHandle,
204}
205
206#[async_trait::async_trait]
207impl Provider for UsageProvider {
208    async fn complete(&self, req: cognate_core::Request) -> cognate_core::Result<cognate_core::Response> {
209        let response = self.inner.complete(req).await?;
210        if let Some(usage) = &response.usage {
211            self.handle
212                .prompt_tokens
213                .fetch_add(usage.prompt_tokens as u64, Ordering::Relaxed);
214            self.handle
215                .completion_tokens
216                .fetch_add(usage.completion_tokens as u64, Ordering::Relaxed);
217        }
218        Ok(response)
219    }
220
221    async fn stream(
222        &self,
223        req: cognate_core::Request,
224    ) -> cognate_core::Result<BoxStream<'static, cognate_core::Result<Chunk>>> {
225        // Streaming doesn't return usage per-chunk; delegate as-is.
226        self.inner.stream(req).await
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use cognate_core::{Choice, Message, MockProvider, ProviderExt, Request, Response, Usage};

    /// Build a single-choice assistant response that reports the given usage.
    fn response_with_usage(prompt_tokens: u32, completion_tokens: u32) -> Response {
        let usage = Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        };
        let choice = Choice {
            index: 0,
            message: Message::assistant("ok"),
            finish_reason: Some("stop".to_string()),
        };
        Response {
            id: "r".to_string(),
            model: "m".to_string(),
            choices: vec![choice],
            usage: Some(usage),
            created: None,
        }
    }

    #[tokio::test]
    async fn test_usage_layer_accumulates() {
        let handle = UsageHandle::new();
        let mock = MockProvider::new();
        for (prompt, completion) in [(10, 5), (20, 8)] {
            mock.push_response(response_with_usage(prompt, completion));
        }
        let provider = mock.with(UsageLayer::new(handle.clone()));

        let req = Request::new().with_model("test");
        for _ in 0..2 {
            provider.complete(req.clone()).await.unwrap();
        }

        assert_eq!(handle.prompt_tokens(), 10 + 20);
        assert_eq!(handle.completion_tokens(), 5 + 8);
        assert_eq!(handle.total_tokens(), 43);
    }
}