use std::sync::Arc;
use bitrouter_core::{
errors::Result,
models::language::{
language_model::{DynLanguageModel, LanguageModel},
stream_result::LanguageModelStreamResult,
},
routers::{router::LanguageModelRouter, routing_table::RoutingTarget},
sync::HotSwap,
};
use crate::engine::Guardrail;
use crate::guarded_model::GuardedModel;
/// Router wrapper that pairs an inner router with a [`Guardrail`].
///
/// The [`LanguageModelRouter`] implementation for this type forwards routing
/// to `inner` and, unless the currently loaded guardrail reports itself as
/// disabled, wraps the routed model in a [`GuardedModel`].
pub struct GuardedRouter<R> {
    /// The wrapped router (or a pointer-like handle that dereferences to one).
    inner: R,
    /// Guardrail handle; it is re-read with `load()` on every routing call.
    guardrail: HotSwap<Guardrail>,
}
impl<R> GuardedRouter<R> {
pub fn new(inner: R, guardrail: Arc<Guardrail>) -> Self {
Self {
inner,
guardrail: HotSwap::from_arc(guardrail),
}
}
pub fn with_hot_swap(inner: R, guardrail: HotSwap<Guardrail>) -> Self {
Self { inner, guardrail }
}
}
/// Routing through a [`GuardedRouter`]: delegate to the inner router, then
/// decide whether the routed model needs to be wrapped in a [`GuardedModel`].
impl<R> LanguageModelRouter for GuardedRouter<R>
where
    R: std::ops::Deref + Send + Sync,
    R::Target: LanguageModelRouter + Send + Sync,
{
    /// Resolves `target` via the inner router and applies the guardrail that is
    /// loaded at the time of the call.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner router's `route_model`.
    async fn route_model(&self, target: RoutingTarget) -> Result<Box<DynLanguageModel<'static>>> {
        let routed = self.inner.route_model(target).await?;

        // The guardrail is read only after routing succeeded, once per call.
        let active = self.guardrail.load();

        let resolved: Box<DynLanguageModel<'static>> = if active.is_disabled() {
            // Disabled guardrail: hand back the routed model without a wrapper.
            routed
        } else {
            DynLanguageModel::new_box(GuardedModel::new(routed, active))
        };

        Ok(resolved)
    }
}
/// [`LanguageModel`] implementation that runs the guardrail around calls to the
/// wrapped model.
///
/// Identification and URL queries are forwarded to the inner model unchanged.
/// `generate` and `stream` pass the call options through the guardrail before
/// contacting the inner model, and the produced output is inspected as well.
impl LanguageModel for GuardedModel {
    /// Provider name reported by the inner model.
    fn provider_name(&self) -> &str {
        self.inner.provider_name()
    }

    /// Model id reported by the inner model.
    fn model_id(&self) -> &str {
        self.inner.model_id()
    }

    /// Supported URL patterns reported by the inner model.
    async fn supported_urls(
        &self,
    ) -> bitrouter_core::models::shared::types::Record<String, regex::Regex> {
        self.inner.supported_urls().await
    }

    /// Runs a non-streaming generation with guardrail checks on both sides.
    ///
    /// The guardrail receives mutable references to the options and to the
    /// result, so it is presumably allowed to rewrite them in place.
    ///
    /// # Errors
    ///
    /// * an `invalid_request` error when the guardrail rejects the call options;
    /// * any error returned by the inner model's `generate`;
    /// * an `invalid_response` error when the guardrail rejects the result.
    ///
    /// Guardrail errors are tagged with the inner model's provider name.
    async fn generate(
        &self,
        mut options: bitrouter_core::models::language::call_options::LanguageModelCallOptions,
    ) -> Result<bitrouter_core::models::language::generate_result::LanguageModelGenerateResult>
    {
        // Inbound check: happens before the inner model is contacted.
        let request_verdict = self.guardrail.inspect_call_options(&mut options);
        request_verdict.map_err(|why| {
            bitrouter_core::errors::BitrouterError::invalid_request(
                Some(self.inner.provider_name()),
                why,
                None,
            )
        })?;

        let mut generated = self.inner.generate(options).await?;

        // Outbound check: the result is only returned if the guardrail accepts it.
        let response_verdict = self.guardrail.inspect_generate_result(&mut generated);
        response_verdict.map_err(|why| {
            bitrouter_core::errors::BitrouterError::invalid_response(
                Some(self.inner.provider_name()),
                why,
                None,
            )
        })?;

        Ok(generated)
    }

    /// Starts a streaming generation whose parts are inspected one at a time.
    ///
    /// The call options are checked up front; the inner model's stream is then
    /// wrapped in a `GuardedStream` carrying a clone of this model's guardrail
    /// handle. The `request` and `response` fields of the inner result are
    /// passed through untouched.
    ///
    /// # Errors
    ///
    /// * an `invalid_request` error when the guardrail rejects the call options;
    /// * any error returned by the inner model's `stream`.
    async fn stream(
        &self,
        mut options: bitrouter_core::models::language::call_options::LanguageModelCallOptions,
    ) -> Result<LanguageModelStreamResult> {
        let request_verdict = self.guardrail.inspect_call_options(&mut options);
        request_verdict.map_err(|why| {
            bitrouter_core::errors::BitrouterError::invalid_request(
                Some(self.inner.provider_name()),
                why,
                None,
            )
        })?;

        let LanguageModelStreamResult {
            stream: upstream,
            request,
            response,
        } = self.inner.stream(options).await?;

        let filtered = GuardedStream::new(upstream, self.guardrail.clone());

        Ok(LanguageModelStreamResult {
            stream: Box::pin(filtered),
            request,
            response,
        })
    }
}
use std::pin::Pin;
use std::task::{Context, Poll};
use bitrouter_core::models::language::stream_part::LanguageModelStreamPart;
/// Stream adapter that passes every part of an upstream stream through a
/// [`Guardrail`] before yielding it.
struct GuardedStream {
    /// Upstream stream of parts, pinned and boxed.
    inner: Pin<Box<dyn futures_core::Stream<Item = LanguageModelStreamPart> + Send>>,
    /// Guardrail used to inspect each part as it is polled.
    guardrail: Arc<Guardrail>,
}
impl GuardedStream {
fn new(
inner: Pin<Box<dyn futures_core::Stream<Item = LanguageModelStreamPart> + Send>>,
guardrail: Arc<Guardrail>,
) -> Self {
Self { inner, guardrail }
}
}
impl futures_core::Stream for GuardedStream {
type Item = LanguageModelStreamPart;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.inner.as_mut().poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(mut part)) => match self.guardrail.inspect_stream_part(&mut part) {
Ok(_) => Poll::Ready(Some(part)),
Err(reason) => {
tracing::warn!(%reason, "guardrail blocked stream part");
Poll::Ready(Some(LanguageModelStreamPart::Error {
error: serde_json::json!({
"error": {
"message": reason,
"type": "guardrail_blocked",
}
}),
}))
}
},
}
}
}