use std::sync::Arc;
use async_trait::async_trait;
use crate::error::Result;
use crate::language_model::{
CallOptions, GenerateResult, LanguageModel, StreamResult, SupportedUrls,
};
/// Identifies which model entry point a call is going through.
///
/// Handed to [`LanguageModelMiddleware::transform_params`] so a middleware can
/// tell a generate call from a stream call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallKind {
/// The call entered through `do_generate`.
Generate,
/// The call entered through `do_stream`.
Stream,
}
/// A layer that can observe or alter calls made to a [`LanguageModel`].
///
/// Every method has a pass-through default, so an implementor only overrides
/// the hooks it needs. Layers are attached with [`wrap_language_model`].
#[async_trait]
pub trait LanguageModelMiddleware: Send + Sync + std::fmt::Debug {
/// Returns a replacement provider name for the wrapped model.
///
/// `None` (the default) keeps the provider reported by `_inner`.
fn override_provider(&self, _inner: &dyn LanguageModel) -> Option<String> {
None
}
/// Returns a replacement model id for the wrapped model.
///
/// `None` (the default) keeps the model id reported by `_inner`.
fn override_model_id(&self, _inner: &dyn LanguageModel) -> Option<String> {
None
}
/// Returns replacement supported-URL information for the wrapped model.
///
/// `None` (the default) makes the wrapper fall back to the inner model's
/// own `supported_urls`.
async fn override_supported_urls(&self, _inner: &dyn LanguageModel) -> Option<SupportedUrls> {
None
}
/// Rewrites the call options before the call is dispatched.
///
/// The wrapper invokes this first, for both generate and stream calls, and
/// passes the result on to [`wrap_generate`](Self::wrap_generate) or
/// [`wrap_stream`](Self::wrap_stream). `_kind` says which of the two is
/// about to run. The default returns `params` unchanged.
///
/// # Errors
///
/// An error returned here is propagated by the wrapper and the wrap hook
/// is not called.
async fn transform_params(
&self,
_kind: CallKind,
params: CallOptions,
_inner: &dyn LanguageModel,
) -> Result<CallOptions> {
Ok(params)
}
/// Surrounds a generate call.
///
/// `next` is the model one layer further in and `params` are the already
/// transformed options. The default simply forwards to `next.do_generate`.
async fn wrap_generate(
&self,
next: &dyn LanguageModel,
params: CallOptions,
) -> Result<GenerateResult> {
next.do_generate(params).await
}
/// Surrounds a stream call.
///
/// `next` is the model one layer further in and `params` are the already
/// transformed options. The default simply forwards to `next.do_stream`.
async fn wrap_stream(
&self,
next: &dyn LanguageModel,
params: CallOptions,
) -> Result<StreamResult> {
next.do_stream(params).await
}
}
/// Wraps `model` in the given middleware layers and returns the layered model.
///
/// The first item yielded by `middleware` becomes the outermost layer, so it
/// sees a call first and its result last. When `middleware` yields nothing,
/// `model` is returned as-is.
pub fn wrap_language_model<I>(
    model: Arc<dyn LanguageModel>,
    middleware: I,
) -> Arc<dyn LanguageModel>
where
    I: IntoIterator<Item = Arc<dyn LanguageModelMiddleware>>,
{
    // Buffer the layers first: an arbitrary `IntoIterator` cannot be walked
    // back-to-front directly.
    let layers: Vec<Arc<dyn LanguageModelMiddleware>> = middleware.into_iter().collect();

    // Wrap starting from the last layer so the first one ends up outermost.
    let mut current = model;
    for layer in layers.into_iter().rev() {
        let wrapped: Arc<dyn LanguageModel> = Arc::new(Wrapped::new(current, layer));
        current = wrapped;
    }
    current
}
/// A model paired with a single middleware layer.
///
/// Stacking several layers is done by nesting: `inner` may itself be another
/// `Wrapped`.
struct Wrapped {
// The model this layer delegates to.
inner: Arc<dyn LanguageModel>,
// The middleware applied around `inner`.
middleware: Arc<dyn LanguageModelMiddleware>,
// Provider name reported by this layer, resolved once in `Wrapped::new`.
provider: String,
// Model id reported by this layer, resolved once in `Wrapped::new`.
model_id: String,
}
impl Wrapped {
    /// Builds one layer around `inner`.
    ///
    /// The provider name and model id are resolved here, once: the
    /// middleware's override wins when it returns `Some`, otherwise the value
    /// is copied from `inner`.
    fn new(inner: Arc<dyn LanguageModel>, middleware: Arc<dyn LanguageModelMiddleware>) -> Self {
        let provider = match middleware.override_provider(inner.as_ref()) {
            Some(name) => name,
            None => inner.provider().to_owned(),
        };
        let model_id = match middleware.override_model_id(inner.as_ref()) {
            Some(id) => id,
            None => inner.model_id().to_owned(),
        };

        Self {
            inner,
            middleware,
            provider,
            model_id,
        }
    }
}
/// Debug output lists the effective identity first, then the middleware and
/// the inner model.
impl std::fmt::Debug for Wrapped {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring makes the compiler flag any field added later that is
        // not considered here.
        let Self {
            inner,
            middleware,
            provider,
            model_id,
        } = self;

        let mut out = f.debug_struct("Wrapped");
        out.field("provider", provider);
        out.field("model_id", model_id);
        out.field("middleware", middleware);
        out.field("inner", inner);
        out.finish()
    }
}
#[async_trait]
impl LanguageModel for Wrapped {
    /// Provider name resolved when this layer was constructed.
    fn provider(&self) -> &str {
        self.provider.as_str()
    }

    /// Model id resolved when this layer was constructed.
    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// Returns the middleware's supported-URL override when it provides one,
    /// otherwise asks the inner model.
    async fn supported_urls(&self) -> SupportedUrls {
        let inner = self.inner.as_ref();
        let overridden = self.middleware.override_supported_urls(inner).await;
        match overridden {
            Some(custom) => custom,
            None => inner.supported_urls().await,
        }
    }

    /// Runs a generate call through this layer: the options are transformed
    /// first, then handed to the middleware's `wrap_generate` together with
    /// the inner model.
    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
        let inner = self.inner.as_ref();
        let middleware = self.middleware.as_ref();

        let params = middleware
            .transform_params(CallKind::Generate, options, inner)
            .await?;
        middleware.wrap_generate(inner, params).await
    }

    /// Runs a stream call through this layer: the options are transformed
    /// first, then handed to the middleware's `wrap_stream` together with the
    /// inner model.
    async fn do_stream(&self, options: CallOptions) -> Result<StreamResult> {
        let inner = self.inner.as_ref();
        let middleware = self.middleware.as_ref();

        let params = middleware
            .transform_params(CallKind::Stream, options, inner)
            .await?;
        middleware.wrap_stream(inner, params).await
    }
}
// Unit tests for `wrap_language_model`: identity overrides, parameter
// transformation, layer ordering, and swapping the call kind inside a layer.
#[cfg(test)]
mod tests {
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use futures::StreamExt;
use futures::stream;
use crate::language_model::{FinishReason, FinishReasonKind, StreamPart, Usage};
use super::*;
// Innermost test model: counts calls per entry point and remembers the
// options of the most recent call.
#[derive(Debug, Default)]
struct MockModel {
provider: String,
model_id: String,
// Number of `do_generate` calls received.
generate_calls: AtomicUsize,
// Number of `do_stream` calls received.
stream_calls: AtomicUsize,
// Options passed to the most recent call of either kind.
last_params: Mutex<Option<CallOptions>>,
}
impl MockModel {
// Creates a mock with the given identity and zeroed counters.
fn new(provider: &str, model_id: &str) -> Self {
Self {
provider: provider.to_owned(),
model_id: model_id.to_owned(),
generate_calls: AtomicUsize::new(0),
stream_calls: AtomicUsize::new(0),
last_params: Mutex::new(None),
}
}
// How many times `do_generate` has been called.
fn generate_count(&self) -> usize {
self.generate_calls.load(Ordering::SeqCst)
}
// How many times `do_stream` has been called.
fn stream_count(&self) -> usize {
self.stream_calls.load(Ordering::SeqCst)
}
// Temperature from the most recent call's options; `None` when no call has
// been recorded or the recorded options carried no temperature.
fn last_temperature(&self) -> Option<f32> {
self.last_params
.lock()
.expect("mock mutex poisoned")
.as_ref()
.and_then(|p| p.temperature)
}
}
#[async_trait]
impl LanguageModel for MockModel {
fn provider(&self) -> &str {
&self.provider
}
fn model_id(&self) -> &str {
&self.model_id
}
// Records the call and returns an empty result that finished with `Stop`.
async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
self.generate_calls.fetch_add(1, Ordering::SeqCst);
*self.last_params.lock().expect("mock mutex poisoned") = Some(options);
Ok(GenerateResult {
content: vec![],
finish_reason: FinishReason::new(FinishReasonKind::Stop),
usage: Usage::default(),
provider_metadata: None,
request: None,
response: None,
warnings: vec![],
})
}
// Records the call and returns a two-part stream: a start marker followed
// by a finish marker.
async fn do_stream(&self, options: CallOptions) -> Result<StreamResult> {
self.stream_calls.fetch_add(1, Ordering::SeqCst);
*self.last_params.lock().expect("mock mutex poisoned") = Some(options);
let parts = stream::iter(vec![
Ok(StreamPart::StreamStart { warnings: vec![] }),
Ok(StreamPart::Finish {
usage: Usage::default(),
finish_reason: FinishReason::new(FinishReasonKind::Stop),
provider_metadata: None,
}),
]);
Ok(StreamResult {
stream: Box::pin(parts),
request: None,
response: None,
})
}
}
// Middleware that replaces both identity strings and bumps the temperature.
#[derive(Debug)]
struct OverrideAndTransform;
#[async_trait]
impl LanguageModelMiddleware for OverrideAndTransform {
fn override_provider(&self, _inner: &dyn LanguageModel) -> Option<String> {
Some("wrapped-provider".to_owned())
}
fn override_model_id(&self, _inner: &dyn LanguageModel) -> Option<String> {
Some("wrapped-model".to_owned())
}
// Adds 1.0 to the temperature, treating a missing temperature as 0.0.
async fn transform_params(
&self,
_kind: CallKind,
mut params: CallOptions,
_inner: &dyn LanguageModel,
) -> Result<CallOptions> {
params.temperature = Some(params.temperature.unwrap_or(0.0) + 1.0);
Ok(params)
}
}
// Middleware that logs "<label>:enter" before and "<label>:exit" after
// delegating a generate call, into a log shared between instances.
#[derive(Debug)]
struct OrderRecorder {
label: &'static str,
log: Arc<Mutex<Vec<String>>>,
}
#[async_trait]
impl LanguageModelMiddleware for OrderRecorder {
async fn wrap_generate(
&self,
next: &dyn LanguageModel,
params: CallOptions,
) -> Result<GenerateResult> {
// Each lock guard is a temporary dropped at the end of its statement, so
// no guard is held across the `.await` below.
self.log
.lock()
.expect("log mutex poisoned")
.push(format!("{}:enter", self.label));
let res = next.do_generate(params).await;
self.log
.lock()
.expect("log mutex poisoned")
.push(format!("{}:exit", self.label));
res
}
}
// Middleware that answers a stream call by running a generate call on the
// next model and returning an empty stream.
#[derive(Debug)]
struct StreamFromGenerate;
#[async_trait]
impl LanguageModelMiddleware for StreamFromGenerate {
async fn wrap_stream(
&self,
next: &dyn LanguageModel,
params: CallOptions,
) -> Result<StreamResult> {
let _ = next.do_generate(params).await?;
Ok(StreamResult {
stream: Box::pin(stream::iter(vec![])),
request: None,
response: None,
})
}
}
// With no middleware the returned model keeps the original identity and
// forwards calls to the original model.
#[tokio::test]
async fn empty_middleware_returns_model_unchanged() {
let model = Arc::new(MockModel::new("openai", "gpt-foo"));
let wrapped: Arc<dyn LanguageModel> =
wrap_language_model(Arc::clone(&model) as _, Vec::new());
assert_eq!(wrapped.provider(), "openai");
assert_eq!(wrapped.model_id(), "gpt-foo");
wrapped
.do_generate(CallOptions::default())
.await
.expect("generate succeeded");
assert_eq!(model.generate_count(), 1);
}
// Identity overrides are visible on the wrapper, and transformed options are
// what the inner model receives.
#[tokio::test]
async fn overrides_replace_identity_and_transform_mutates_params() {
let model = Arc::new(MockModel::new("openai", "gpt-foo"));
let wrapped = wrap_language_model(
Arc::clone(&model) as _,
[Arc::new(OverrideAndTransform) as Arc<dyn LanguageModelMiddleware>],
);
assert_eq!(wrapped.provider(), "wrapped-provider");
assert_eq!(wrapped.model_id(), "wrapped-model");
wrapped
.do_generate(CallOptions::default())
.await
.expect("generate succeeded");
// NOTE(review): expecting 1.0 presumes the default options carry no
// temperature (or 0.0) - confirm against `CallOptions::default()`.
assert_eq!(model.last_temperature(), Some(1.0));
}
// The first middleware in the list must be entered first and exited last.
#[tokio::test]
async fn wrap_order_runs_first_middleware_outermost() {
let model = Arc::new(MockModel::new("openai", "gpt-foo"));
let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let m1 = Arc::new(OrderRecorder {
label: "m1",
log: Arc::clone(&log),
}) as Arc<dyn LanguageModelMiddleware>;
let m2 = Arc::new(OrderRecorder {
label: "m2",
log: Arc::clone(&log),
}) as Arc<dyn LanguageModelMiddleware>;
let wrapped = wrap_language_model(model, [m1, m2]);
wrapped
.do_generate(CallOptions::default())
.await
.expect("generate succeeded");
let entries = log.lock().expect("log mutex poisoned").clone();
assert_eq!(
entries,
vec!["m1:enter", "m2:enter", "m2:exit", "m1:exit"],
"first middleware must be outermost",
);
}
// A middleware may serve a stream call by calling `do_generate` on `next`;
// the inner model's `do_stream` is then never reached.
#[tokio::test]
async fn middleware_can_swap_call_kind_via_next() {
let model = Arc::new(MockModel::new("openai", "gpt-foo"));
let wrapped = wrap_language_model(
Arc::clone(&model) as _,
[Arc::new(StreamFromGenerate) as Arc<dyn LanguageModelMiddleware>],
);
let mut stream = wrapped
.do_stream(CallOptions::default())
.await
.expect("stream succeeded")
.stream;
assert!(stream.next().await.is_none());
assert_eq!(model.generate_count(), 1, "do_generate was used internally");
assert_eq!(model.stream_count(), 0, "do_stream on inner was bypassed");
}
}