use crate::core::types::{GenerateOptions, GenerateResult, Prompt, StreamPart};
use crate::core::{LanguageModel, Result};
use async_trait::async_trait;
use futures::stream::BoxStream;
/// Interception points around a [`LanguageModel`].
///
/// Every hook has a pass-through default, so an implementor overrides only the
/// hooks it cares about. When installed with `wrap_language_model`,
/// `transform_params` runs first and the `wrap_*` hooks receive the options it
/// produced.
#[async_trait]
pub trait LanguageModelMiddleware: Send + Sync {
    /// Rewrites the generation options before the call proceeds.
    ///
    /// The default hands `options` back untouched. An `Err` aborts the call.
    async fn transform_params(&self, options: GenerateOptions) -> Result<GenerateOptions> {
        Ok(options)
    }

    /// Optionally takes over a non-streaming `generate` call.
    ///
    /// `model` is the next model in the chain. Returning `Some(..)` makes that
    /// value the call's result. Returning `None` (the default) lets the wrapped
    /// model handle the call itself.
    async fn wrap_generate(
        &self,
        _prompt: &Prompt,
        _options: &GenerateOptions,
        _model: &dyn LanguageModel,
    ) -> Option<Result<GenerateResult>> {
        None
    }

    /// Streaming counterpart of `wrap_generate`.
    ///
    /// Returning `None` (the default) forwards the call to the wrapped model's
    /// `generate_stream`.
    async fn wrap_generate_stream(
        &self,
        _prompt: &Prompt,
        _options: &GenerateOptions,
        _model: &dyn LanguageModel,
    ) -> Option<Result<BoxStream<'static, StreamPart>>> {
        None
    }
}
/// Wraps `model` in the given middlewares and returns the combined model.
///
/// The first element of `middlewares` becomes the outermost layer, so its hooks
/// see each call first. The last element sits directly around `model`. An empty
/// list returns `model` itself.
pub fn wrap_language_model(
    model: Box<dyn LanguageModel>,
    middlewares: Vec<Box<dyn LanguageModelMiddleware>>,
) -> Box<dyn LanguageModel> {
    // Build inside-out: folding over the reversed list leaves the first
    // middleware as the final (outermost) wrapper.
    middlewares
        .into_iter()
        .rev()
        .fold(model, |inner, middleware| -> Box<dyn LanguageModel> {
            Box::new(WrappedModel { inner, middleware })
        })
}
/// One middleware layer: a model plus the middleware that intercepts its calls.
struct WrappedModel {
    /// The next model in the chain (possibly another `WrappedModel`).
    inner: Box<dyn LanguageModel>,
    /// The hooks applied to every call before `inner` is reached.
    middleware: Box<dyn LanguageModelMiddleware>,
}

#[async_trait]
impl LanguageModel for WrappedModel {
    /// Applies `transform_params`, then lets `wrap_generate` take over.
    /// If the middleware declines, the call is forwarded to `inner`.
    async fn generate(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> Result<GenerateResult> {
        let options = self.middleware.transform_params(options).await?;
        let intercepted = self
            .middleware
            .wrap_generate(&prompt, &options, &*self.inner)
            .await;
        match intercepted {
            Some(outcome) => outcome,
            None => self.inner.generate(prompt, options).await,
        }
    }

    /// Streaming variant of `generate`, using `wrap_generate_stream`.
    async fn generate_stream(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> Result<BoxStream<'static, StreamPart>> {
        let options = self.middleware.transform_params(options).await?;
        let intercepted = self
            .middleware
            .wrap_generate_stream(&prompt, &options, &*self.inner)
            .await;
        match intercepted {
            Some(outcome) => outcome,
            None => self.inner.generate_stream(prompt, options).await,
        }
    }
}
/// Middleware that supplies fallback sampling settings.
///
/// A field is applied only when the matching option on the incoming
/// [`GenerateOptions`] is `None`. Values the caller set explicitly are kept.
pub struct DefaultSettingsMiddleware {
    /// Fallback for `GenerateOptions::temperature`.
    pub temperature: Option<f32>,
    /// Fallback for `GenerateOptions::max_tokens`.
    pub max_tokens: Option<u32>,
    /// Fallback for `GenerateOptions::top_p`.
    pub top_p: Option<f32>,
}

#[async_trait]
impl LanguageModelMiddleware for DefaultSettingsMiddleware {
    /// Fills each unset option from `self`. This step never fails.
    async fn transform_params(&self, options: GenerateOptions) -> Result<GenerateOptions> {
        let mut options = options;
        // `Option::or` keeps the caller's `Some` and otherwise falls back to ours.
        options.temperature = options.temperature.or(self.temperature);
        options.max_tokens = options.max_tokens.or(self.max_tokens);
        options.top_p = options.top_p.or(self.top_p);
        Ok(options)
    }
}
/// Middleware that removes `open_tag ... close_tag` spans from the text of
/// non-streaming `generate` results.
///
/// The removed text is discarded. Streaming calls are not intercepted.
pub struct ExtractReasoningMiddleware {
    /// Marker that opens a span to remove.
    pub open_tag: String,
    /// Marker that closes a span to remove.
    pub close_tag: String,
}

impl Default for ExtractReasoningMiddleware {
    /// Uses the `<think>` / `</think>` tag pair.
    fn default() -> Self {
        Self {
            open_tag: String::from("<think>"),
            close_tag: String::from("</think>"),
        }
    }
}

#[async_trait]
impl LanguageModelMiddleware for ExtractReasoningMiddleware {
    /// Runs the wrapped model, then strips tagged spans from a successful
    /// result's text. Errors pass through unchanged.
    async fn wrap_generate(
        &self,
        prompt: &Prompt,
        options: &GenerateOptions,
        model: &dyn LanguageModel,
    ) -> Option<Result<GenerateResult>> {
        let outcome = model.generate(prompt.clone(), options.clone()).await;
        Some(outcome.map(|mut generated| {
            generated.text =
                extract_reasoning(&generated.text, &self.open_tag, &self.close_tag);
            generated
        }))
    }
}
/// Middleware that unwraps Markdown code fences around the text of
/// non-streaming `generate` results, leaving bare JSON for the caller.
pub struct ExtractJsonMiddleware;

#[async_trait]
impl LanguageModelMiddleware for ExtractJsonMiddleware {
    /// Runs the wrapped model and post-processes a successful result's text
    /// with `extract_json_from_fences`. Errors pass through unchanged.
    async fn wrap_generate(
        &self,
        prompt: &Prompt,
        options: &GenerateOptions,
        model: &dyn LanguageModel,
    ) -> Option<Result<GenerateResult>> {
        Some(
            model
                .generate(prompt.clone(), options.clone())
                .await
                .map(|mut generated| {
                    generated.text = extract_json_from_fences(&generated.text);
                    generated
                }),
        )
    }
}
pub struct SimulateStreamingMiddleware;
#[async_trait]
impl LanguageModelMiddleware for SimulateStreamingMiddleware {
async fn wrap_generate_stream(
&self,
prompt: &Prompt,
options: &GenerateOptions,
model: &dyn LanguageModel,
) -> Option<Result<BoxStream<'static, StreamPart>>> {
let result = model.generate(prompt.clone(), options.clone()).await;
match result {
Ok(gen_result) => {
let text = gen_result.text;
let usage = gen_result.usage;
let finish_reason = gen_result.finish_reason;
let stream = async_stream::stream! {
yield StreamPart::Usage { usage };
for ch in text.chars() {
yield StreamPart::TextDelta { delta: ch.to_string() };
}
yield StreamPart::Finish { finish_reason };
};
Some(Ok(Box::pin(stream)))
}
Err(e) => Some(Err(e)),
}
}
}
/// Removes every `open_tag ... close_tag` span (tags included) from `text` and
/// trims surrounding whitespace from what remains.
///
/// Matching rules:
/// - Spans are matched left to right. Each opening tag is closed by the first
///   `close_tag` that starts after it, so spans do not nest.
/// - An opening tag with no matching close tag is kept verbatim, along with
///   everything after it.
/// - Removal is a single pass, so the text on either side of a removed span is
///   never re-scanned as a newly formed tag.
/// - If either tag is empty, the text is only trimmed, because an empty
///   delimiter cannot mark a span.
fn extract_reasoning(text: &str, open_tag: &str, close_tag: &str) -> String {
    if open_tag.is_empty() || close_tag.is_empty() {
        // `str::find("")` always matches at offset 0. A rewrite loop keyed on
        // it would never shrink the text and would spin forever.
        return text.trim().to_string();
    }

    let mut kept = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(start) = rest.find(open_tag) {
        // Search for the close tag only after the open tag, so a close tag that
        // overlaps the open tag's own characters cannot end the span early.
        let after_open = &rest[start + open_tag.len()..];
        let end = match after_open.find(close_tag) {
            Some(end) => end,
            None => break,
        };
        kept.push_str(&rest[..start]);
        rest = &after_open[end + close_tag.len()..];
    }
    kept.push_str(rest);
    kept.trim().to_string()
}
/// Unwraps model output that is entirely enclosed in a Markdown code fence, so
/// the result can be handed to a JSON parser.
///
/// The input is trimmed first. If it then starts and ends with ```` ``` ````:
/// - The fence markers are removed.
/// - An info string on the opening line is removed, whatever language it names
///   and in any case (`json`, `JSON`, `jsonc`, `javascript`, ...).
/// - If the content starts on the opening line (```` ```json{"a":1}``` ````), a
///   leading `json` is dropped instead.
///
/// Any other input is returned trimmed but otherwise unchanged.
fn extract_json_from_fences(text: &str) -> String {
    let trimmed = text.trim();
    let inner = match trimmed
        .strip_prefix("```")
        .and_then(|s| s.strip_suffix("```"))
    {
        Some(inner) => inner,
        None => return trimmed.to_string(),
    };
    let body = match inner.split_once('\n') {
        // `trim` also drops the `\r` of a CRLF line ending.
        Some((first_line, rest)) if is_fence_info_string(first_line.trim()) => rest,
        _ => inner.strip_prefix("json").unwrap_or(inner),
    };
    body.trim().to_string()
}

/// Returns `true` if `line` looks like a fence info string, i.e. a language tag
/// such as `json`, `JSON` or `c++`, rather than the first line of a JSON value.
fn is_fence_info_string(line: &str) -> bool {
    // Apart from the literals `true`/`false`/`null`, a JSON value never starts
    // with a letter, so a bare word here must be a language tag.
    matches!(line.chars().next(), Some(c) if c.is_ascii_alphabetic())
        && !matches!(line, "true" | "false" | "null")
        && line
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '+' | '.'))
}