// bitrouter_core/hooks/router.rs

1use std::sync::Arc;
2
3use crate::{
4    errors::Result,
5    models::language::language_model::DynLanguageModel,
6    routers::{model_router::LanguageModelRouter, routing_table::RoutingTarget},
7};
8
9use super::{GenerationHook, HookedModel};
10
11/// A [`LanguageModelRouter`] wrapper that attaches [`GenerationHook`]s to
12/// every model returned by the inner router.
13///
14/// When the hooks slice is empty the wrapper is a zero-cost pass-through —
15/// it returns the inner model unchanged.
16pub struct HookedRouter<R> {
17    inner: R,
18    hooks: Arc<[Arc<dyn GenerationHook>]>,
19}
20
21impl<R> HookedRouter<R> {
22    /// Wrap an existing router with generation hooks.
23    ///
24    /// If `hooks` is empty, models are returned unwrapped.
25    pub fn new(inner: R, hooks: Arc<[Arc<dyn GenerationHook>]>) -> Self {
26        Self { inner, hooks }
27    }
28}
29
30impl<R> LanguageModelRouter for HookedRouter<R>
31where
32    R: std::ops::Deref + Send + Sync,
33    R::Target: LanguageModelRouter + Send + Sync,
34{
35    async fn route_model(&self, target: RoutingTarget) -> Result<Box<DynLanguageModel<'static>>> {
36        let model = self.inner.route_model(target).await?;
37
38        if self.hooks.is_empty() {
39            return Ok(model);
40        }
41
42        Ok(DynLanguageModel::new_box(HookedModel::new(
43            model,
44            self.hooks.clone(),
45        )))
46    }
47}