bitrouter_core/hooks/
model.rs1use std::sync::Arc;
2
3use crate::models::language::language_model::DynLanguageModel;
4
5use super::GenerationHook;
6
7pub struct HookedModel {
14 pub(crate) inner: Box<DynLanguageModel<'static>>,
15 pub(crate) hooks: Arc<[Arc<dyn GenerationHook>]>,
16}
17
18impl HookedModel {
19 pub fn new(
20 inner: Box<DynLanguageModel<'static>>,
21 hooks: Arc<[Arc<dyn GenerationHook>]>,
22 ) -> Self {
23 Self { inner, hooks }
24 }
25}
26
27impl crate::models::language::language_model::LanguageModel for HookedModel {
28 fn provider_name(&self) -> &str {
29 self.inner.provider_name()
30 }
31
32 fn model_id(&self) -> &str {
33 self.inner.model_id()
34 }
35
36 async fn supported_urls(&self) -> crate::models::shared::types::Record<String, regex::Regex> {
37 self.inner.supported_urls().await
38 }
39
40 async fn generate(
41 &self,
42 options: crate::models::language::call_options::LanguageModelCallOptions,
43 ) -> crate::errors::Result<crate::models::language::generate_result::LanguageModelGenerateResult>
44 {
45 let result = match self.inner.generate(options).await {
46 Ok(r) => r,
47 Err(e) => {
48 for hook in self.hooks.iter() {
49 hook.on_generate_error(&e);
50 }
51 return Err(e);
52 }
53 };
54
55 for hook in self.hooks.iter() {
56 hook.on_generate_result(&result);
57 }
58
59 Ok(result)
60 }
61
62 async fn stream(
63 &self,
64 options: crate::models::language::call_options::LanguageModelCallOptions,
65 ) -> crate::errors::Result<crate::models::language::stream_result::LanguageModelStreamResult>
66 {
67 let result = match self.inner.stream(options).await {
68 Ok(r) => r,
69 Err(e) => {
70 for hook in self.hooks.iter() {
71 hook.on_generate_error(&e);
72 }
73 return Err(e);
74 }
75 };
76
77 let hooked_stream = super::stream::HookedStream::new(result.stream, self.hooks.clone());
78
79 Ok(
80 crate::models::language::stream_result::LanguageModelStreamResult {
81 stream: Box::pin(hooked_stream),
82 request: result.request,
83 response: result.response,
84 },
85 )
86 }
87}