Skip to main content

bitrouter_core/hooks/
model.rs

1use std::sync::Arc;
2
3use crate::models::language::language_model::DynLanguageModel;
4
5use super::{GenerationContext, GenerationHook};
6
7/// A model wrapper that invokes [`GenerationHook`] callbacks after
8/// `generate()` completes and for each streaming part yielded by `stream()`.
9///
10/// The wrapper is a pure observer — it never modifies requests or responses.
11/// The [`LanguageModel`](crate::models::language::language_model::LanguageModel)
12/// implementation lives below in this file so all trait methods are co-located.
13pub struct HookedModel {
14    pub(crate) inner: Box<DynLanguageModel<'static>>,
15    pub(crate) hooks: Arc<[Arc<dyn GenerationHook>]>,
16}
17
18impl HookedModel {
19    pub fn new(
20        inner: Box<DynLanguageModel<'static>>,
21        hooks: Arc<[Arc<dyn GenerationHook>]>,
22    ) -> Self {
23        Self { inner, hooks }
24    }
25}
26
27impl crate::models::language::language_model::LanguageModel for HookedModel {
28    fn provider_name(&self) -> &str {
29        self.inner.provider_name()
30    }
31
32    fn model_id(&self) -> &str {
33        self.inner.model_id()
34    }
35
36    async fn supported_urls(&self) -> crate::models::shared::types::Record<String, regex::Regex> {
37        self.inner.supported_urls().await
38    }
39
40    async fn generate(
41        &self,
42        options: crate::models::language::call_options::LanguageModelCallOptions,
43    ) -> crate::errors::Result<crate::models::language::generate_result::LanguageModelGenerateResult>
44    {
45        let result = match self.inner.generate(options).await {
46            Ok(r) => r,
47            Err(e) => {
48                for hook in self.hooks.iter() {
49                    hook.on_generate_error(&e);
50                }
51                return Err(e);
52            }
53        };
54
55        let ctx = GenerationContext {
56            model_id: self.inner.model_id(),
57            provider_name: self.inner.provider_name(),
58        };
59
60        for hook in self.hooks.iter() {
61            hook.on_generate_result(&ctx, &result);
62        }
63
64        Ok(result)
65    }
66
67    async fn stream(
68        &self,
69        options: crate::models::language::call_options::LanguageModelCallOptions,
70    ) -> crate::errors::Result<crate::models::language::stream_result::LanguageModelStreamResult>
71    {
72        let result = match self.inner.stream(options).await {
73            Ok(r) => r,
74            Err(e) => {
75                for hook in self.hooks.iter() {
76                    hook.on_generate_error(&e);
77                }
78                return Err(e);
79            }
80        };
81
82        let hooked_stream = super::stream::HookedStream::new(
83            result.stream,
84            self.hooks.clone(),
85            self.inner.model_id().to_owned(),
86            self.inner.provider_name().to_owned(),
87        );
88
89        Ok(
90            crate::models::language::stream_result::LanguageModelStreamResult {
91                stream: Box::pin(hooked_stream),
92                request: result.request,
93                response: result.response,
94            },
95        )
96    }
97}