pub mod approval_gate;
pub mod context_editing;
pub mod context_injection;
pub mod filesystem;
pub mod human_in_the_loop;
pub mod model_call_limit;
pub mod model_fallback;
pub mod model_retry;
pub mod patch_tool_calls;
pub mod pii;
pub mod planning;
pub mod prompt_caching;
pub mod rate_limit;
pub mod recovery;
pub mod redaction;
pub mod subagent;
pub mod summarization;
pub mod todo;
pub mod token_counter;
pub mod tool_call_limit;
pub mod tool_emulator;
pub mod tool_retry;
pub mod tool_selection;
#[cfg(test)]
pub(crate) mod tests_util;
use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::{Message, Result};
use cognis_llm::chat::{ChatOptions, ChatResponse};
use cognis_llm::tools::ToolDefinition;
use cognis_llm::Client;
pub use approval_gate::{ApprovalGate, AutoApproveAll, AutoRejectAll, ChatApproval, ChatApprover};
pub use context_editing::{CapMessageLength, ContextEditing, DropMatching, EditPolicy};
pub use context_injection::{ContextInjection, ContextProvider, FnContextProvider};
pub use filesystem::{FilesystemMiddleware, WorkspaceLister};
pub use human_in_the_loop::{AlwaysSkip, HumanDecision, HumanInTheLoop, HumanResponder};
pub use model_call_limit::ModelCallLimit;
pub use model_fallback::ModelFallback;
pub use model_retry::ModelRetry;
pub use patch_tool_calls::{FnToolCallPatcher, PatchToolCalls, ToolCallPatcher};
pub use pii::PiiRedactor;
pub use planning::Planning;
pub use prompt_caching::PromptCaching;
pub use rate_limit::{
CompositeLimiter, CostBasedLimiter, RateLimit, RateLimiter, SlidingWindowLimiter, TokenBucket,
};
pub use recovery::{FixedRecovery, FnRecovery, Recovery, RecoveryStrategy};
pub use redaction::RegexRedactor;
pub use subagent::{SubagentMiddleware, SubagentRouter};
pub use summarization::Summarization;
pub use todo::TodoMiddleware;
pub use token_counter::TokenCounter;
pub use tool_call_limit::ToolCallLimit;
pub use tool_emulator::{EmulatorSource, MapEmulator, ToolEmulator};
pub use tool_retry::{ToolRetry, ToolRetryClassifier};
pub use tool_selection::{LimitTools, ToolAllowList, ToolDenyList, ToolFilter, ToolSelection};
/// Request state handed from one middleware layer to the next.
///
/// Every layer receives this value by ownership and may rewrite any field
/// before forwarding it; the terminal link passes the three fields to the
/// client's provider unchanged.
#[derive(Debug, Clone)]
pub struct MiddlewareCtx {
    /// Conversation history to send to the model.
    pub messages: Vec<Message>,
    /// Tool schemas exposed to the model for this call.
    pub tool_defs: Vec<ToolDefinition>,
    /// Per-call chat options.
    pub opts: ChatOptions,
}
impl MiddlewareCtx {
pub fn new(messages: Vec<Message>, tool_defs: Vec<ToolDefinition>, opts: ChatOptions) -> Self {
Self {
messages,
tool_defs,
opts,
}
}
}
/// The remainder of the chain, as seen from inside a middleware layer.
///
/// Invoking it forwards the (possibly modified) context either to the next
/// layer or, when no layers remain, to the client's provider.
#[async_trait]
pub trait Next: Send + Sync {
    /// Forwards `ctx` down the chain and resolves to the eventual response.
    async fn invoke(&self, ctx: MiddlewareCtx) -> Result<ChatResponse>;
}
/// A layer that wraps a chat request on its way to the model.
///
/// An implementation receives the request context by value together with a
/// handle to the rest of the chain. It decides whether and how to call
/// `next.invoke`; returning without calling it skips all inner layers and the
/// provider.
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Handles one request, optionally delegating to `next`.
    async fn call(&self, ctx: MiddlewareCtx, next: Arc<dyn Next>) -> Result<ChatResponse>;
    /// Human-readable label for this layer; defaults to `"Middleware"`.
    fn name(&self) -> &str {
        "Middleware"
    }
}
/// Builder that collects middleware layers before binding them to a [`Client`].
///
/// Cloning is cheap for the layers (each is an `Arc`), so a base pipeline can be
/// cloned and extended in different directions without rebuilding its layers.
#[derive(Clone)]
pub struct MiddlewarePipeline {
    /// Layers in push order; index 0 ends up as the outermost layer.
    layers: Vec<Arc<dyn Middleware>>,
}
/// Shows the layers by [`Middleware::name`], because `dyn Middleware` itself
/// does not implement `Debug`.
impl std::fmt::Debug for MiddlewarePipeline {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let names: Vec<&str> = self.layers.iter().map(|layer| layer.name()).collect();
        f.debug_struct("MiddlewarePipeline")
            .field("layers", &names)
            .finish()
    }
}
impl Default for MiddlewarePipeline {
fn default() -> Self {
Self::new()
}
}
impl MiddlewarePipeline {
    /// Creates an empty pipeline with no middleware layers.
    #[must_use]
    pub fn new() -> Self {
        Self { layers: Vec::new() }
    }
    /// Appends a middleware layer, taking ownership of it.
    ///
    /// Layers pushed earlier wrap layers pushed later: the first layer pushed is
    /// the outermost one and sees each request first.
    #[must_use = "the pipeline is consumed; use the returned value or the layer is lost"]
    pub fn push(mut self, m: impl Middleware + 'static) -> Self {
        self.layers.push(Arc::new(m));
        self
    }
    /// Appends an already-shared middleware layer.
    ///
    /// Useful when one middleware instance should be reused across several
    /// pipelines. Ordering follows the same rule as [`MiddlewarePipeline::push`].
    #[must_use = "the pipeline is consumed; use the returned value or the layer is lost"]
    pub fn push_boxed(mut self, m: Arc<dyn Middleware>) -> Self {
        self.layers.push(m);
        self
    }
    /// Finishes the builder, using `client` as the innermost handler that
    /// receives the request after every layer has run.
    #[must_use]
    pub fn build(self, client: Client) -> PipelinedClient {
        PipelinedClient {
            client,
            layers: self.layers,
        }
    }
}
/// A [`Client`] wrapped by an ordered stack of middleware layers.
///
/// Produced by [`MiddlewarePipeline::build`]. Cloning shares the layers, which
/// are held behind `Arc`.
#[derive(Clone)]
pub struct PipelinedClient {
    /// Innermost handler whose provider performs the actual chat call.
    client: Client,
    /// Layers in push order; index 0 is the outermost.
    layers: Vec<Arc<dyn Middleware>>,
}
impl PipelinedClient {
    /// Runs one chat request through every middleware layer and then the client.
    ///
    /// The chain is assembled on each call. The innermost link forwards to the
    /// client's provider, and layers are wrapped around it from last-pushed to
    /// first-pushed. As a result, the first middleware pushed is the outermost
    /// and sees the request first.
    pub async fn invoke(
        &self,
        messages: Vec<Message>,
        tool_defs: Vec<ToolDefinition>,
        opts: ChatOptions,
    ) -> Result<ChatResponse> {
        let mut head: Arc<dyn Next> = Arc::new(ClientNext {
            client: self.client.clone(),
        });
        for layer in self.layers.iter().rev() {
            head = Arc::new(LayerNext {
                layer: Arc::clone(layer),
                next: head,
            });
        }
        head.invoke(MiddlewareCtx::new(messages, tool_defs, opts))
            .await
    }
    /// Borrows the underlying client that sits beneath all layers.
    pub fn client(&self) -> &Client {
        &self.client
    }
}
/// Terminal link of the chain: hands the context straight to the provider.
struct ClientNext {
    /// Client whose provider receives the final request.
    client: Client,
}
#[async_trait]
impl Next for ClientNext {
async fn invoke(&self, ctx: MiddlewareCtx) -> Result<ChatResponse> {
self.client
.provider()
.chat_completion_with_tools(ctx.messages, ctx.tool_defs, ctx.opts)
.await
}
}
/// One link of the chain: runs `layer`, giving it `next` as its continuation.
struct LayerNext {
    /// Middleware executed at this position in the chain.
    layer: Arc<dyn Middleware>,
    /// Everything inside this layer (further layers, then the client).
    next: Arc<dyn Next>,
}
#[async_trait]
impl Next for LayerNext {
async fn invoke(&self, ctx: MiddlewareCtx) -> Result<ChatResponse> {
self.layer.call(ctx, self.next.clone()).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use async_trait::async_trait;
    use cognis_llm::chat::{HealthStatus, StreamChunk, Usage};
    use cognis_llm::provider::{LLMProvider, Provider};
    /// Fake provider that echoes the last message back and counts its calls.
    struct Echo {
        calls: Arc<AtomicUsize>,
    }
    #[async_trait]
    impl LLMProvider for Echo {
        fn name(&self) -> &str {
            "echo"
        }
        fn provider_type(&self) -> Provider {
            Provider::Ollama
        }
        async fn chat_completion(
            &self,
            messages: Vec<Message>,
            _opts: ChatOptions,
        ) -> Result<ChatResponse> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            // Only build the fallback message when there is no last message.
            let last = messages
                .last()
                .cloned()
                .unwrap_or_else(|| Message::ai(""));
            Ok(ChatResponse {
                message: Message::ai(format!("echo: {}", last.content())),
                usage: Some(Usage::default()),
                finish_reason: "stop".into(),
                model: "echo".into(),
            })
        }
        async fn chat_completion_stream(
            &self,
            _: Vec<Message>,
            _: ChatOptions,
        ) -> Result<cognis_core::RunnableStream<StreamChunk>> {
            unimplemented!()
        }
        async fn health_check(&self) -> Result<HealthStatus> {
            Ok(HealthStatus::Healthy { latency_ms: 0 })
        }
    }
    /// Middleware that counts how many requests pass through it.
    struct CountingMw {
        seen: Arc<AtomicUsize>,
    }
    #[async_trait]
    impl Middleware for CountingMw {
        async fn call(&self, ctx: MiddlewareCtx, next: Arc<dyn Next>) -> Result<ChatResponse> {
            self.seen.fetch_add(1, Ordering::SeqCst);
            next.invoke(ctx).await
        }
        fn name(&self) -> &str {
            "Counting"
        }
    }
    /// Middleware that appends its label to a shared log on entry, so tests can
    /// observe the order in which layers run.
    struct RecordingMw {
        label: &'static str,
        log: Arc<Mutex<Vec<&'static str>>>,
    }
    #[async_trait]
    impl Middleware for RecordingMw {
        async fn call(&self, ctx: MiddlewareCtx, next: Arc<dyn Next>) -> Result<ChatResponse> {
            // The guard is a temporary dropped at the end of this statement,
            // so the std mutex is never held across the await below.
            self.log.lock().unwrap().push(self.label);
            next.invoke(ctx).await
        }
        fn name(&self) -> &str {
            self.label
        }
    }
    fn client() -> (Client, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        (
            Client::new(Arc::new(Echo {
                calls: calls.clone(),
            })),
            calls,
        )
    }
    #[tokio::test]
    async fn empty_pipeline_passes_through() {
        let (c, calls) = client();
        let pipe = MiddlewarePipeline::new().build(c);
        let resp = pipe
            .invoke(
                vec![Message::human("hi")],
                Vec::new(),
                ChatOptions::default(),
            )
            .await
            .unwrap();
        assert_eq!(resp.message.content(), "echo: hi");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
    #[tokio::test]
    async fn each_layer_sees_call_exactly_once() {
        let (c, calls) = client();
        let a = Arc::new(AtomicUsize::new(0));
        let b = Arc::new(AtomicUsize::new(0));
        let pipe = MiddlewarePipeline::new()
            .push(CountingMw { seen: a.clone() })
            .push(CountingMw { seen: b.clone() })
            .build(c);
        let _ = pipe
            .invoke(
                vec![Message::human("hi")],
                Vec::new(),
                ChatOptions::default(),
            )
            .await
            .unwrap();
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
    #[tokio::test]
    async fn layers_run_in_push_order() {
        let (c, calls) = client();
        let log = Arc::new(Mutex::new(Vec::new()));
        let pipe = MiddlewarePipeline::new()
            .push(RecordingMw {
                label: "first",
                log: log.clone(),
            })
            .push(RecordingMw {
                label: "second",
                log: log.clone(),
            })
            .build(c);
        let _ = pipe
            .invoke(
                vec![Message::human("hi")],
                Vec::new(),
                ChatOptions::default(),
            )
            .await
            .unwrap();
        // The first layer pushed is the outermost, so it runs first.
        assert_eq!(*log.lock().unwrap(), vec!["first", "second"]);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}