cognate_axum/lib.rs
1//! Axum integration for Cognate.
2//!
3//! # Features
4//!
5//! * [`CognateProvider`] — an Axum extractor that pulls an `Arc<dyn Provider>`
6//! out of shared application state via [`axum::extract::FromRef`].
//! * [`UsageLayer`] / [`UsageHandle`] — a `cognate_core` middleware layer that
//!   wraps a [`Provider`] and accumulates the token usage reported by its
//!   non-streaming `complete` calls.
9//! * [`into_sse`] — convert a `BoxStream<Result<Chunk>>` directly into an Axum
10//! [`Sse`] response.
11//!
12//! # Wiring up the provider
13//!
14//! ```rust,no_run
15//! use axum::{Router, routing::post};
16//! use cognate_axum::CognateProvider;
17//! use cognate_core::{Provider, Request};
18//! use std::sync::Arc;
19//!
20//! #[derive(Clone)]
21//! struct AppState {
22//! provider: Arc<dyn Provider>,
23//! }
24//!
25//! impl axum::extract::FromRef<AppState> for Arc<dyn Provider> {
26//! fn from_ref(state: &AppState) -> Self {
27//! state.provider.clone()
28//! }
29//! }
30//!
31//! async fn chat(
32//! CognateProvider(provider): CognateProvider,
33//! axum::Json(req): axum::Json<Request>,
34//! ) -> String {
35//! provider.complete(req).await.unwrap().content().to_string()
36//! }
37//!
38//! #[tokio::main]
39//! async fn main() {
40//! let provider: Arc<dyn Provider> = unimplemented!();
41//! let _app: Router = Router::new()
42//! .route("/chat", post(chat))
43//! .with_state(AppState { provider });
44//! }
45//! ```
46#![warn(missing_docs)]
47
48use axum::{
49 extract::{FromRef, FromRequestParts},
50 http::request::Parts,
51 response::sse::{Event, Sse},
52};
53use cognate_core::{Chunk, Provider};
54use futures::stream::{BoxStream, Stream, StreamExt};
55use std::{
56 convert::Infallible,
57 sync::{
58 atomic::{AtomicU64, Ordering},
59 Arc,
60 },
61};
62
63// ─── CognateProvider extractor ─────────────────────────────────────────────
64
65/// Axum extractor that resolves an `Arc<dyn Provider>` from shared application
66/// state.
67///
68/// The application state `S` must implement
69/// `axum::extract::FromRef<S, Target = Arc<dyn Provider>>`.
70///
71/// See the [crate documentation](crate) for a full wiring example.
72pub struct CognateProvider(pub Arc<dyn Provider>);
73
74impl<S> FromRequestParts<S> for CognateProvider
75where
76 Arc<dyn Provider>: FromRef<S>,
77 S: Send + Sync,
78{
79 type Rejection = Infallible;
80
81 async fn from_request_parts(
82 _parts: &mut Parts,
83 state: &S,
84 ) -> Result<Self, Infallible> {
85 Ok(CognateProvider(Arc::<dyn Provider>::from_ref(state)))
86 }
87}
88
89// ─── SSE helper ────────────────────────────────────────────────────────────
90
91/// Convert a streaming provider response into an Axum [`Sse`] response.
92///
93/// Text deltas are emitted as `data:` events. Provider errors are emitted as
94/// `event: error` events so the client can detect them.
95///
96/// # Example
97///
98/// ```rust,no_run
99/// use axum::{extract::State, response::Sse};
100/// use cognate_axum::{CognateProvider, into_sse};
101/// use cognate_core::{Request, Message};
102///
103/// async fn stream_handler(
104/// CognateProvider(provider): CognateProvider,
105/// axum::Json(req): axum::Json<Request>,
106/// ) -> impl axum::response::IntoResponse {
107/// let stream = provider.stream(req).await.unwrap();
108/// into_sse(stream)
109/// }
110/// ```
111pub fn into_sse(
112 stream: BoxStream<'static, cognate_core::Result<Chunk>>,
113) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
114 let mapped = stream.map(|result| match result {
115 Ok(chunk) => {
116 let mut event = Event::default().data(chunk.content());
117 if chunk.is_finished() {
118 event = event.event("done");
119 }
120 Ok(event)
121 }
122 Err(e) => Ok(Event::default().event("error").data(e.to_string())),
123 });
124 Sse::new(mapped)
125}
126
127// ─── UsageHandle ───────────────────────────────────────────────────────────
128
129/// A handle for reading token usage accumulated by [`UsageLayer`].
130///
131/// Clone freely — all clones share the same underlying counters.
132#[derive(Debug, Clone, Default)]
133pub struct UsageHandle {
134 /// Total prompt tokens seen across all requests.
135 pub prompt_tokens: Arc<AtomicU64>,
136 /// Total completion tokens seen across all requests.
137 pub completion_tokens: Arc<AtomicU64>,
138}
139
140impl UsageHandle {
141 /// Create a new, zeroed handle.
142 pub fn new() -> Self {
143 Self::default()
144 }
145
146 /// Total prompt tokens accumulated so far.
147 pub fn prompt_tokens(&self) -> u64 {
148 self.prompt_tokens.load(Ordering::Relaxed)
149 }
150
151 /// Total completion tokens accumulated so far.
152 pub fn completion_tokens(&self) -> u64 {
153 self.completion_tokens.load(Ordering::Relaxed)
154 }
155
156 /// Total tokens accumulated so far.
157 pub fn total_tokens(&self) -> u64 {
158 self.prompt_tokens() + self.completion_tokens()
159 }
160}
161
162// ─── UsageLayer (Tower Layer) ──────────────────────────────────────────────
163
164/// Tower [`Layer`] that wraps a [`Provider`] and records token usage.
165///
166/// # Example
167///
168/// ```rust,no_run
169/// use cognate_axum::{UsageLayer, UsageHandle};
170/// use cognate_core::{Provider, ProviderExt, middleware::MiddlewareLayer};
171/// # use cognate_core::MockProvider;
172///
173/// let handle = UsageHandle::new();
174/// let layer = UsageLayer::new(handle.clone());
175/// let provider = MockProvider::new().with(layer);
176/// // After requests, inspect: handle.total_tokens()
177/// ```
178#[derive(Clone)]
179pub struct UsageLayer {
180 handle: UsageHandle,
181}
182
183impl UsageLayer {
184 /// Create a new layer that reports into `handle`.
185 pub fn new(handle: UsageHandle) -> Self {
186 Self { handle }
187 }
188}
189
190impl cognate_core::middleware::Layer for UsageLayer {
191 fn layer(&self, provider: Arc<dyn Provider>) -> Arc<dyn Provider> {
192 Arc::new(UsageProvider {
193 inner: provider,
194 handle: self.handle.clone(),
195 })
196 }
197}
198
199// ─── UsageProvider ─────────────────────────────────────────────────────────
200
201struct UsageProvider {
202 inner: Arc<dyn Provider>,
203 handle: UsageHandle,
204}
205
206#[async_trait::async_trait]
207impl Provider for UsageProvider {
208 async fn complete(&self, req: cognate_core::Request) -> cognate_core::Result<cognate_core::Response> {
209 let response = self.inner.complete(req).await?;
210 if let Some(usage) = &response.usage {
211 self.handle
212 .prompt_tokens
213 .fetch_add(usage.prompt_tokens as u64, Ordering::Relaxed);
214 self.handle
215 .completion_tokens
216 .fetch_add(usage.completion_tokens as u64, Ordering::Relaxed);
217 }
218 Ok(response)
219 }
220
221 async fn stream(
222 &self,
223 req: cognate_core::Request,
224 ) -> cognate_core::Result<BoxStream<'static, cognate_core::Result<Chunk>>> {
225 // Streaming doesn't return usage per-chunk; delegate as-is.
226 self.inner.stream(req).await
227 }
228}
229
230#[cfg(test)]
231mod tests {
232 use super::*;
233 use cognate_core::{Choice, Message, MockProvider, ProviderExt, Request, Response, Usage};
234
235 fn make_response_with_usage(prompt: u32, completion: u32) -> Response {
236 Response {
237 id: "r".to_string(),
238 model: "m".to_string(),
239 choices: vec![Choice {
240 index: 0,
241 message: Message::assistant("ok"),
242 finish_reason: Some("stop".to_string()),
243 }],
244 usage: Some(Usage {
245 prompt_tokens: prompt,
246 completion_tokens: completion,
247 total_tokens: prompt + completion,
248 }),
249 created: None,
250 }
251 }
252
253 #[tokio::test]
254 async fn test_usage_layer_accumulates() {
255 let handle = UsageHandle::new();
256 let mock = MockProvider::new();
257 mock.push_response(make_response_with_usage(10, 5));
258 mock.push_response(make_response_with_usage(20, 8));
259
260 let provider = mock.with(UsageLayer::new(handle.clone()));
261
262 let req = Request::new().with_model("test");
263 provider.complete(req.clone()).await.unwrap();
264 provider.complete(req).await.unwrap();
265
266 assert_eq!(handle.prompt_tokens(), 30);
267 assert_eq!(handle.completion_tokens(), 13);
268 assert_eq!(handle.total_tokens(), 43);
269 }
270}