Skip to main content

bitrouter_core/hooks/
model.rs

1use std::sync::Arc;
2
3use crate::models::language::language_model::DynLanguageModel;
4
5use super::GenerationHook;
6
7/// A model wrapper that invokes [`GenerationHook`] callbacks after
8/// `generate()` completes and for each streaming part yielded by `stream()`.
9///
10/// The wrapper is a pure observer — it never modifies requests or responses.
11/// The [`LanguageModel`](crate::models::language::language_model::LanguageModel)
12/// implementation lives below in this file so all trait methods are co-located.
13pub struct HookedModel {
14    pub(crate) inner: Box<DynLanguageModel<'static>>,
15    pub(crate) hooks: Arc<[Arc<dyn GenerationHook>]>,
16}
17
18impl HookedModel {
19    pub fn new(
20        inner: Box<DynLanguageModel<'static>>,
21        hooks: Arc<[Arc<dyn GenerationHook>]>,
22    ) -> Self {
23        Self { inner, hooks }
24    }
25}
26
27impl crate::models::language::language_model::LanguageModel for HookedModel {
28    fn provider_name(&self) -> &str {
29        self.inner.provider_name()
30    }
31
32    fn model_id(&self) -> &str {
33        self.inner.model_id()
34    }
35
36    async fn supported_urls(&self) -> crate::models::shared::types::Record<String, regex::Regex> {
37        self.inner.supported_urls().await
38    }
39
40    async fn generate(
41        &self,
42        options: crate::models::language::call_options::LanguageModelCallOptions,
43    ) -> crate::errors::Result<crate::models::language::generate_result::LanguageModelGenerateResult>
44    {
45        let result = match self.inner.generate(options).await {
46            Ok(r) => r,
47            Err(e) => {
48                for hook in self.hooks.iter() {
49                    hook.on_generate_error(&e);
50                }
51                return Err(e);
52            }
53        };
54
55        for hook in self.hooks.iter() {
56            hook.on_generate_result(&result);
57        }
58
59        Ok(result)
60    }
61
62    async fn stream(
63        &self,
64        options: crate::models::language::call_options::LanguageModelCallOptions,
65    ) -> crate::errors::Result<crate::models::language::stream_result::LanguageModelStreamResult>
66    {
67        let result = match self.inner.stream(options).await {
68            Ok(r) => r,
69            Err(e) => {
70                for hook in self.hooks.iter() {
71                    hook.on_generate_error(&e);
72                }
73                return Err(e);
74            }
75        };
76
77        let hooked_stream = super::stream::HookedStream::new(result.stream, self.hooks.clone());
78
79        Ok(
80            crate::models::language::stream_result::LanguageModelStreamResult {
81                stream: Box::pin(hooked_stream),
82                request: result.request,
83                response: result.response,
84            },
85        )
86    }
87}