1mod model;
2mod router;
3mod stream;
4
5pub use model::HookedModel;
6pub use router::HookedRouter;
7
8use crate::errors::BitrouterError;
9use crate::models::language::{
10 generate_result::LanguageModelGenerateResult, stream_part::LanguageModelStreamPart,
11};
12
/// Borrowed metadata handed to [`GenerationHook`] callbacks.
///
/// The struct only holds two string slices, so it is `Copy`: hooks and
/// wrappers can pass it around by value without cloning or re-borrowing.
/// `PartialEq`/`Eq` allow contexts to be compared directly (handy in tests).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GenerationContext<'a> {
    /// Identifier of the model the callback refers to.
    pub model_id: &'a str,
    /// Name of the provider the callback refers to.
    pub provider_name: &'a str,
}
24
25pub trait GenerationHook: Send + Sync {
32 fn on_generate_result(
34 &self,
35 _ctx: &GenerationContext<'_>,
36 _result: &LanguageModelGenerateResult,
37 ) {
38 }
39
40 fn on_generate_error(&self, _error: &BitrouterError) {}
42
43 fn on_stream_part(&self, _ctx: &GenerationContext<'_>, _part: &LanguageModelStreamPart) {}
49}
50
51#[cfg(test)]
52mod tests {
53 use std::pin::Pin;
54 use std::sync::{Arc, atomic::AtomicU32};
55
56 use crate::models::language::{
57 finish_reason::LanguageModelFinishReason,
58 generate_result::{LanguageModelGenerateResult, LanguageModelRawRequest},
59 stream_part::LanguageModelStreamPart,
60 usage::{LanguageModelInputTokens, LanguageModelOutputTokens, LanguageModelUsage},
61 };
62
63 use super::stream::HookedStream;
64 use super::*;
65
66 fn test_usage() -> LanguageModelUsage {
67 LanguageModelUsage {
68 input_tokens: LanguageModelInputTokens {
69 total: Some(10),
70 no_cache: None,
71 cache_read: None,
72 cache_write: None,
73 },
74 output_tokens: LanguageModelOutputTokens {
75 total: Some(20),
76 text: None,
77 reasoning: None,
78 },
79 raw: None,
80 }
81 }
82
83 fn test_generate_result() -> LanguageModelGenerateResult {
84 LanguageModelGenerateResult {
85 content: crate::models::language::content::LanguageModelContent::Text {
86 text: String::new(),
87 provider_metadata: None,
88 },
89 finish_reason: LanguageModelFinishReason::Stop,
90 usage: test_usage(),
91 provider_metadata: None,
92 request: Some(LanguageModelRawRequest {
93 headers: None,
94 body: serde_json::json!({}),
95 }),
96 response_metadata: None,
97 warnings: None,
98 }
99 }
100
/// Test double that counts how often each hook callback fires.
///
/// The counters are atomics because the hook callbacks take `&self` and the
/// hook is shared behind `Arc` in the tests. `Default` yields a hook with
/// every counter at zero.
#[derive(Default)]
struct CountingHook {
    /// Number of `on_generate_result` calls observed.
    generate_count: AtomicU32,
    /// Number of `on_generate_error` calls observed.
    error_count: AtomicU32,
    /// Number of `on_stream_part` calls observed.
    stream_count: AtomicU32,
}

impl CountingHook {
    /// Creates a hook with all counters at zero.
    fn new() -> Self {
        Self::default()
    }
}
117
118 impl GenerationHook for CountingHook {
119 fn on_generate_result(
120 &self,
121 _ctx: &GenerationContext<'_>,
122 _result: &LanguageModelGenerateResult,
123 ) {
124 self.generate_count
125 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
126 }
127
128 fn on_generate_error(&self, _error: &crate::errors::BitrouterError) {
129 self.error_count
130 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
131 }
132
133 fn on_stream_part(&self, _ctx: &GenerationContext<'_>, _part: &LanguageModelStreamPart) {
134 self.stream_count
135 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
136 }
137 }
138
/// A hook that overrides nothing must accept every callback; the test passes
/// as long as none of the default methods panic.
#[test]
fn default_hook_methods_are_noop() {
    struct NoopHook;
    impl GenerationHook for NoopHook {}

    // Fixtures for the three callbacks.
    let ctx = GenerationContext {
        model_id: "test-model",
        provider_name: "test-provider",
    };
    let result = test_generate_result();
    let error = crate::errors::BitrouterError::transport(None, "test");
    let part = LanguageModelStreamPart::TextDelta {
        id: "t1".into(),
        delta: "hello".into(),
        provider_metadata: None,
    };

    let hook = NoopHook;
    hook.on_generate_result(&ctx, &result);
    hook.on_generate_error(&error);
    hook.on_stream_part(&ctx, &part);
}
160
161 #[tokio::test]
162 async fn hooked_stream_invokes_hooks_for_each_part() {
163 let hook = Arc::new(CountingHook::new());
164 let hooks: Arc<[Arc<dyn GenerationHook>]> =
165 Arc::from(vec![hook.clone() as Arc<dyn GenerationHook>]);
166
167 let parts = vec![
168 LanguageModelStreamPart::StreamStart {
169 warnings: Vec::new(),
170 },
171 LanguageModelStreamPart::TextDelta {
172 id: "t1".into(),
173 delta: "hello".into(),
174 provider_metadata: None,
175 },
176 LanguageModelStreamPart::TextDelta {
177 id: "t1".into(),
178 delta: " world".into(),
179 provider_metadata: None,
180 },
181 LanguageModelStreamPart::Finish {
182 usage: test_usage(),
183 finish_reason: LanguageModelFinishReason::Stop,
184 provider_metadata: None,
185 },
186 ];
187
188 let inner: Pin<Box<dyn futures_core::Stream<Item = LanguageModelStreamPart> + Send>> =
189 Box::pin(tokio_stream::iter(parts));
190
191 let hooked = HookedStream::new(
192 inner,
193 hooks,
194 "test-model".to_owned(),
195 "test-provider".to_owned(),
196 );
197 let mut hooked = Box::pin(hooked);
198
199 use tokio_stream::StreamExt as _;
200 let mut collected = Vec::new();
201 while let Some(part) = hooked.next().await {
202 collected.push(part);
203 }
204
205 assert_eq!(collected.len(), 4);
206 assert_eq!(
207 hook.stream_count.load(std::sync::atomic::Ordering::SeqCst),
208 4
209 );
210 }
211
212 #[tokio::test]
213 async fn multiple_hooks_all_invoked() {
214 let hook_a = Arc::new(CountingHook::new());
215 let hook_b = Arc::new(CountingHook::new());
216 let hooks: Arc<[Arc<dyn GenerationHook>]> = Arc::from(vec![
217 hook_a.clone() as Arc<dyn GenerationHook>,
218 hook_b.clone() as Arc<dyn GenerationHook>,
219 ]);
220
221 let parts = vec![
222 LanguageModelStreamPart::TextDelta {
223 id: "t1".into(),
224 delta: "hi".into(),
225 provider_metadata: None,
226 },
227 LanguageModelStreamPart::Finish {
228 usage: test_usage(),
229 finish_reason: LanguageModelFinishReason::Stop,
230 provider_metadata: None,
231 },
232 ];
233
234 let inner: Pin<Box<dyn futures_core::Stream<Item = LanguageModelStreamPart> + Send>> =
235 Box::pin(tokio_stream::iter(parts));
236
237 let hooked = HookedStream::new(
238 inner,
239 hooks,
240 "test-model".to_owned(),
241 "test-provider".to_owned(),
242 );
243 let mut hooked = Box::pin(hooked);
244
245 use tokio_stream::StreamExt as _;
246 while hooked.next().await.is_some() {}
247
248 assert_eq!(
249 hook_a
250 .stream_count
251 .load(std::sync::atomic::Ordering::SeqCst),
252 2
253 );
254 assert_eq!(
255 hook_b
256 .stream_count
257 .load(std::sync::atomic::Ordering::SeqCst),
258 2
259 );
260 }
261
/// Two `on_generate_result` calls must leave `generate_count` at two.
#[test]
fn on_generate_result_invoked() {
    use std::sync::atomic::Ordering::SeqCst;

    let hook = Arc::new(CountingHook::new());
    let ctx = GenerationContext {
        model_id: "test-model",
        provider_name: "test-provider",
    };
    let result = test_generate_result();

    for _ in 0..2 {
        hook.on_generate_result(&ctx, &result);
    }

    assert_eq!(hook.generate_count.load(SeqCst), 2);
}
280
/// Two `on_generate_error` calls must leave `error_count` at two.
#[test]
fn on_generate_error_invoked() {
    use std::sync::atomic::Ordering::SeqCst;

    let hook = Arc::new(CountingHook::new());
    let error = crate::errors::BitrouterError::transport(None, "connection failed");

    for _ in 0..2 {
        hook.on_generate_error(&error);
    }

    assert_eq!(hook.error_count.load(SeqCst), 2);
}
294}