// bitrouter_core/hooks/mod.rs

1mod model;
2mod router;
3mod stream;
4
5pub use model::HookedModel;
6pub use router::HookedRouter;
7
8use crate::errors::BitrouterError;
9use crate::models::language::{
10    generate_result::LanguageModelGenerateResult, stream_part::LanguageModelStreamPart,
11};
12
/// Identifies which model/provider pair produced a generation.
///
/// Every [`GenerationHook`] callback receives one of these, letting a hook
/// attribute a result to its source without holding onto the original
/// request.
#[derive(Debug, Clone)]
pub struct GenerationContext<'a> {
    /// The upstream provider's model identifier
    /// (e.g. `"meta-llama/Llama-4-Maverick-17B-128E-Instruct"`).
    pub model_id: &'a str,
    /// The name of the provider that served the request (e.g. `"chutes-ai"`).
    pub provider_name: &'a str,
}
24
25/// A hook that observes generation lifecycle events for side-effect purposes
26/// (logging, metrics, token tracking, auditing).
27///
28/// Hooks receive borrowed references to core types and must not block.
29/// All methods have default no-op implementations so consumers only
30/// override the events they care about.
31pub trait GenerationHook: Send + Sync {
32    /// Called after a non-streaming `generate()` call completes successfully.
33    fn on_generate_result(
34        &self,
35        _ctx: &GenerationContext<'_>,
36        _result: &LanguageModelGenerateResult,
37    ) {
38    }
39
40    /// Called when `generate()` or `stream()` returns an error.
41    fn on_generate_error(&self, _error: &BitrouterError) {}
42
43    /// Called for each streaming part as it is yielded from the model stream.
44    ///
45    /// To capture token usage from streaming responses, match on
46    /// [`LanguageModelStreamPart::Finish`] which carries
47    /// [`LanguageModelUsage`](crate::models::language::usage::LanguageModelUsage).
48    fn on_stream_part(&self, _ctx: &GenerationContext<'_>, _part: &LanguageModelStreamPart) {}
49}
50
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::sync::{Arc, atomic::AtomicU32};

    use crate::models::language::{
        finish_reason::LanguageModelFinishReason,
        generate_result::{LanguageModelGenerateResult, LanguageModelRawRequest},
        stream_part::LanguageModelStreamPart,
        usage::{LanguageModelInputTokens, LanguageModelOutputTokens, LanguageModelUsage},
    };

    use super::stream::HookedStream;
    use super::*;

    /// Builds a minimal usage record (10 input / 20 output tokens) shared
    /// across tests; all optional breakdown fields are left unset.
    fn test_usage() -> LanguageModelUsage {
        LanguageModelUsage {
            input_tokens: LanguageModelInputTokens {
                total: Some(10),
                no_cache: None,
                cache_read: None,
                cache_write: None,
            },
            output_tokens: LanguageModelOutputTokens {
                total: Some(20),
                text: None,
                reasoning: None,
            },
            raw: None,
        }
    }

    /// Builds a minimal successful generate result: empty text content,
    /// `Stop` finish reason, the usage from [`test_usage`], and an empty
    /// raw-request body.
    fn test_generate_result() -> LanguageModelGenerateResult {
        LanguageModelGenerateResult {
            content: crate::models::language::content::LanguageModelContent::Text {
                text: String::new(),
                provider_metadata: None,
            },
            finish_reason: LanguageModelFinishReason::Stop,
            usage: test_usage(),
            provider_metadata: None,
            request: Some(LanguageModelRawRequest {
                headers: None,
                body: serde_json::json!({}),
            }),
            response_metadata: None,
            warnings: None,
        }
    }

    /// A test hook that counts invocations.
    ///
    /// Atomics are used because `GenerationHook` takes `&self`, so the
    /// counters must be mutable through a shared reference.
    struct CountingHook {
        generate_count: AtomicU32,
        error_count: AtomicU32,
        stream_count: AtomicU32,
    }

    impl CountingHook {
        /// Creates a hook with all counters at zero.
        fn new() -> Self {
            Self {
                generate_count: AtomicU32::new(0),
                error_count: AtomicU32::new(0),
                stream_count: AtomicU32::new(0),
            }
        }
    }

    // Each callback bumps the corresponding counter; payloads are ignored.
    impl GenerationHook for CountingHook {
        fn on_generate_result(
            &self,
            _ctx: &GenerationContext<'_>,
            _result: &LanguageModelGenerateResult,
        ) {
            self.generate_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        }

        fn on_generate_error(&self, _error: &crate::errors::BitrouterError) {
            self.error_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        }

        fn on_stream_part(&self, _ctx: &GenerationContext<'_>, _part: &LanguageModelStreamPart) {
            self.stream_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        }
    }

    /// An impl with no overridden methods must compile and every callback
    /// must be callable without side effects or panics.
    #[test]
    fn default_hook_methods_are_noop() {
        struct NoopHook;
        impl GenerationHook for NoopHook {}

        let hook = NoopHook;
        let ctx = GenerationContext {
            model_id: "test-model",
            provider_name: "test-provider",
        };
        // Exercise all three default no-op bodies.
        hook.on_generate_result(&ctx, &test_generate_result());
        hook.on_generate_error(&crate::errors::BitrouterError::transport(None, "test"));
        hook.on_stream_part(
            &ctx,
            &LanguageModelStreamPart::TextDelta {
                id: "t1".into(),
                delta: "hello".into(),
                provider_metadata: None,
            },
        );
    }

    /// `HookedStream` must pass every part through unchanged AND notify the
    /// hook once per part (4 parts in, 4 parts out, 4 hook calls).
    #[tokio::test]
    async fn hooked_stream_invokes_hooks_for_each_part() {
        let hook = Arc::new(CountingHook::new());
        // HookedStream takes a shared slice of type-erased hooks; keep a
        // concrete clone so the counters can be inspected afterwards.
        let hooks: Arc<[Arc<dyn GenerationHook>]> =
            Arc::from(vec![hook.clone() as Arc<dyn GenerationHook>]);

        let parts = vec![
            LanguageModelStreamPart::StreamStart {
                warnings: Vec::new(),
            },
            LanguageModelStreamPart::TextDelta {
                id: "t1".into(),
                delta: "hello".into(),
                provider_metadata: None,
            },
            LanguageModelStreamPart::TextDelta {
                id: "t1".into(),
                delta: " world".into(),
                provider_metadata: None,
            },
            LanguageModelStreamPart::Finish {
                usage: test_usage(),
                finish_reason: LanguageModelFinishReason::Stop,
                provider_metadata: None,
            },
        ];

        let inner: Pin<Box<dyn futures_core::Stream<Item = LanguageModelStreamPart> + Send>> =
            Box::pin(tokio_stream::iter(parts));

        let hooked = HookedStream::new(
            inner,
            hooks,
            "test-model".to_owned(),
            "test-provider".to_owned(),
        );
        let mut hooked = Box::pin(hooked);

        use tokio_stream::StreamExt as _;
        let mut collected = Vec::new();
        while let Some(part) = hooked.next().await {
            collected.push(part);
        }

        // All parts forwarded, and the hook saw each one exactly once.
        assert_eq!(collected.len(), 4);
        assert_eq!(
            hook.stream_count.load(std::sync::atomic::Ordering::SeqCst),
            4
        );
    }

    /// When several hooks are registered, each must be invoked for every
    /// stream part (2 parts => count of 2 on both hooks).
    #[tokio::test]
    async fn multiple_hooks_all_invoked() {
        let hook_a = Arc::new(CountingHook::new());
        let hook_b = Arc::new(CountingHook::new());
        let hooks: Arc<[Arc<dyn GenerationHook>]> = Arc::from(vec![
            hook_a.clone() as Arc<dyn GenerationHook>,
            hook_b.clone() as Arc<dyn GenerationHook>,
        ]);

        let parts = vec![
            LanguageModelStreamPart::TextDelta {
                id: "t1".into(),
                delta: "hi".into(),
                provider_metadata: None,
            },
            LanguageModelStreamPart::Finish {
                usage: test_usage(),
                finish_reason: LanguageModelFinishReason::Stop,
                provider_metadata: None,
            },
        ];

        let inner: Pin<Box<dyn futures_core::Stream<Item = LanguageModelStreamPart> + Send>> =
            Box::pin(tokio_stream::iter(parts));

        let hooked = HookedStream::new(
            inner,
            hooks,
            "test-model".to_owned(),
            "test-provider".to_owned(),
        );
        let mut hooked = Box::pin(hooked);

        use tokio_stream::StreamExt as _;
        // Drain the stream; only the hook counters matter here.
        while hooked.next().await.is_some() {}

        assert_eq!(
            hook_a
                .stream_count
                .load(std::sync::atomic::Ordering::SeqCst),
            2
        );
        assert_eq!(
            hook_b
                .stream_count
                .load(std::sync::atomic::Ordering::SeqCst),
            2
        );
    }

    /// Direct `on_generate_result` calls must increment the counter once
    /// per call (trait dispatch through the `Arc` deref).
    #[test]
    fn on_generate_result_invoked() {
        let hook = Arc::new(CountingHook::new());
        let result = test_generate_result();
        let ctx = GenerationContext {
            model_id: "test-model",
            provider_name: "test-provider",
        };

        hook.on_generate_result(&ctx, &result);
        hook.on_generate_result(&ctx, &result);

        assert_eq!(
            hook.generate_count
                .load(std::sync::atomic::Ordering::SeqCst),
            2
        );
    }

    /// Direct `on_generate_error` calls must increment the error counter
    /// once per call.
    #[test]
    fn on_generate_error_invoked() {
        let hook = Arc::new(CountingHook::new());
        let error = crate::errors::BitrouterError::transport(None, "connection failed");

        hook.on_generate_error(&error);
        hook.on_generate_error(&error);

        assert_eq!(
            hook.error_count.load(std::sync::atomic::Ordering::SeqCst),
            2
        );
    }
}