// bitrouter_core/hooks/model.rs
use std::sync::Arc;

use crate::models::language::language_model::DynLanguageModel;

use super::GenerationHook;

7pub struct HookedModel {
14 pub(crate) inner: Box<DynLanguageModel<'static>>,
15 pub(crate) hooks: Arc<[Arc<dyn GenerationHook>]>,
16}
17
18impl HookedModel {
19 pub fn new(
20 inner: Box<DynLanguageModel<'static>>,
21 hooks: Arc<[Arc<dyn GenerationHook>]>,
22 ) -> Self {
23 Self { inner, hooks }
24 }
25}
26
27impl crate::models::language::language_model::LanguageModel for HookedModel {
28 fn provider_name(&self) -> &str {
29 self.inner.provider_name()
30 }
31
32 fn model_id(&self) -> &str {
33 self.inner.model_id()
34 }
35
36 async fn supported_urls(&self) -> crate::models::shared::types::Record<String, regex::Regex> {
37 self.inner.supported_urls().await
38 }
39
40 async fn generate(
41 &self,
42 options: crate::models::language::call_options::LanguageModelCallOptions,
43 ) -> crate::errors::Result<crate::models::language::generate_result::LanguageModelGenerateResult>
44 {
45 let result = self.inner.generate(options).await?;
46
47 for hook in self.hooks.iter() {
48 hook.on_generate_result(&result);
49 }
50
51 Ok(result)
52 }
53
54 async fn stream(
55 &self,
56 options: crate::models::language::call_options::LanguageModelCallOptions,
57 ) -> crate::errors::Result<crate::models::language::stream_result::LanguageModelStreamResult>
58 {
59 let result = self.inner.stream(options).await?;
60
61 let hooked_stream = super::stream::HookedStream::new(result.stream, self.hooks.clone());
62
63 Ok(
64 crate::models::language::stream_result::LanguageModelStreamResult {
65 stream: Box::pin(hooked_stream),
66 request: result.request,
67 response: result.response,
68 },
69 )
70 }
71}