Skip to main content

bitrouter_core/hooks/
model.rs

1use std::sync::Arc;
2
3use crate::models::language::language_model::DynLanguageModel;
4
5use super::GenerationHook;
6
7/// A model wrapper that invokes [`GenerationHook`] callbacks after
8/// `generate()` completes and for each streaming part yielded by `stream()`.
9///
10/// The wrapper is a pure observer — it never modifies requests or responses.
11/// The [`LanguageModel`](crate::models::language::language_model::LanguageModel)
12/// implementation lives below in this file so all trait methods are co-located.
13pub struct HookedModel {
14    pub(crate) inner: Box<DynLanguageModel<'static>>,
15    pub(crate) hooks: Arc<[Arc<dyn GenerationHook>]>,
16}
17
18impl HookedModel {
19    pub fn new(
20        inner: Box<DynLanguageModel<'static>>,
21        hooks: Arc<[Arc<dyn GenerationHook>]>,
22    ) -> Self {
23        Self { inner, hooks }
24    }
25}
26
27impl crate::models::language::language_model::LanguageModel for HookedModel {
28    fn provider_name(&self) -> &str {
29        self.inner.provider_name()
30    }
31
32    fn model_id(&self) -> &str {
33        self.inner.model_id()
34    }
35
36    async fn supported_urls(&self) -> crate::models::shared::types::Record<String, regex::Regex> {
37        self.inner.supported_urls().await
38    }
39
40    async fn generate(
41        &self,
42        options: crate::models::language::call_options::LanguageModelCallOptions,
43    ) -> crate::errors::Result<crate::models::language::generate_result::LanguageModelGenerateResult>
44    {
45        let result = self.inner.generate(options).await?;
46
47        for hook in self.hooks.iter() {
48            hook.on_generate_result(&result);
49        }
50
51        Ok(result)
52    }
53
54    async fn stream(
55        &self,
56        options: crate::models::language::call_options::LanguageModelCallOptions,
57    ) -> crate::errors::Result<crate::models::language::stream_result::LanguageModelStreamResult>
58    {
59        let result = self.inner.stream(options).await?;
60
61        let hooked_stream = super::stream::HookedStream::new(result.stream, self.hooks.clone());
62
63        Ok(
64            crate::models::language::stream_result::LanguageModelStreamResult {
65                stream: Box::pin(hooked_stream),
66                request: result.request,
67                response: result.response,
68            },
69        )
70    }
71}