bitrouter_core/hooks/
model.rs1use std::sync::Arc;
2
3use crate::models::language::language_model::DynLanguageModel;
4
5use super::{GenerationContext, GenerationHook};
6
7pub struct HookedModel {
14 pub(crate) inner: Box<DynLanguageModel<'static>>,
15 pub(crate) hooks: Arc<[Arc<dyn GenerationHook>]>,
16}
17
18impl HookedModel {
19 pub fn new(
20 inner: Box<DynLanguageModel<'static>>,
21 hooks: Arc<[Arc<dyn GenerationHook>]>,
22 ) -> Self {
23 Self { inner, hooks }
24 }
25}
26
27impl crate::models::language::language_model::LanguageModel for HookedModel {
28 fn provider_name(&self) -> &str {
29 self.inner.provider_name()
30 }
31
32 fn model_id(&self) -> &str {
33 self.inner.model_id()
34 }
35
36 async fn supported_urls(&self) -> crate::models::shared::types::Record<String, regex::Regex> {
37 self.inner.supported_urls().await
38 }
39
40 async fn generate(
41 &self,
42 options: crate::models::language::call_options::LanguageModelCallOptions,
43 ) -> crate::errors::Result<crate::models::language::generate_result::LanguageModelGenerateResult>
44 {
45 let result = match self.inner.generate(options).await {
46 Ok(r) => r,
47 Err(e) => {
48 for hook in self.hooks.iter() {
49 hook.on_generate_error(&e);
50 }
51 return Err(e);
52 }
53 };
54
55 let ctx = GenerationContext {
56 model_id: self.inner.model_id(),
57 provider_name: self.inner.provider_name(),
58 };
59
60 for hook in self.hooks.iter() {
61 hook.on_generate_result(&ctx, &result);
62 }
63
64 Ok(result)
65 }
66
67 async fn stream(
68 &self,
69 options: crate::models::language::call_options::LanguageModelCallOptions,
70 ) -> crate::errors::Result<crate::models::language::stream_result::LanguageModelStreamResult>
71 {
72 let result = match self.inner.stream(options).await {
73 Ok(r) => r,
74 Err(e) => {
75 for hook in self.hooks.iter() {
76 hook.on_generate_error(&e);
77 }
78 return Err(e);
79 }
80 };
81
82 let hooked_stream = super::stream::HookedStream::new(
83 result.stream,
84 self.hooks.clone(),
85 self.inner.model_id().to_owned(),
86 self.inner.provider_name().to_owned(),
87 );
88
89 Ok(
90 crate::models::language::stream_result::LanguageModelStreamResult {
91 stream: Box::pin(hooked_stream),
92 request: result.request,
93 response: result.response,
94 },
95 )
96 }
97}