use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use cognis_core::error::Result;
use cognis_core::language_models::chat_model::{
BaseChatModel, ChatStream, ModelProfile, ToolChoice,
};
use cognis_core::messages::Message;
use cognis_core::outputs::ChatResult;
use cognis_core::tools::ToolSchema;
pub struct TokenBucket {
capacity: f64,
tokens: Arc<Mutex<f64>>,
refill_rate: f64,
last_refill: Arc<Mutex<Instant>>,
}
impl TokenBucket {
pub fn new(capacity: f64, refill_rate: f64) -> Self {
Self {
capacity,
tokens: Arc::new(Mutex::new(capacity)),
refill_rate,
last_refill: Arc::new(Mutex::new(Instant::now())),
}
}
pub async fn acquire(&self) {
loop {
self.refill();
let acquired = self.try_acquire_token();
if acquired {
return;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
fn try_acquire_token(&self) -> bool {
let mut tokens = self.tokens.lock().unwrap();
if *tokens >= 1.0 {
*tokens -= 1.0;
true
} else {
false
}
}
fn refill(&self) {
let mut last = self.last_refill.lock().unwrap();
let now = Instant::now();
let elapsed = now.duration_since(*last).as_secs_f64();
if elapsed > 0.0 {
let mut tokens = self.tokens.lock().unwrap();
*tokens = (*tokens + elapsed * self.refill_rate).min(self.capacity);
*last = now;
}
}
#[cfg(test)]
fn available(&self) -> f64 {
self.refill();
*self.tokens.lock().unwrap()
}
}
/// Wraps any `BaseChatModel` and throttles calls through a token bucket.
pub struct RateLimitedChatModel {
/// The underlying model every call is delegated to.
inner: Box<dyn BaseChatModel>,
/// One token is consumed per `_generate`/`_stream` call.
request_limiter: TokenBucket,
/// Optional per-token budget. Stored by `with_token_limit` but not yet
/// consulted by any call path in this file — hence the dead-code allow.
#[allow(dead_code)]
token_limiter: Option<TokenBucket>,
}
impl RateLimitedChatModel {
    /// Wraps `inner` with a limit of `requests_per_minute`. The bucket's
    /// capacity equals a full minute's allowance, so bursts of up to that
    /// many requests pass without waiting.
    pub fn new(inner: Box<dyn BaseChatModel>, requests_per_minute: f64) -> Self {
        let per_second = requests_per_minute / 60.0;
        Self {
            inner,
            request_limiter: TokenBucket::new(requests_per_minute, per_second),
            token_limiter: None,
        }
    }

    /// Wraps `inner` with an explicit steady rate (`max_requests_per_second`)
    /// and an independent burst capacity (`burst_size`).
    pub fn with_rate(
        inner: Box<dyn BaseChatModel>,
        max_requests_per_second: f64,
        burst_size: usize,
    ) -> Self {
        let limiter = TokenBucket::new(burst_size as f64, max_requests_per_second);
        Self {
            inner,
            request_limiter: limiter,
            token_limiter: None,
        }
    }

    /// Builder-style addition of a token budget of `tokens_per_minute`.
    pub fn with_token_limit(mut self, tokens_per_minute: f64) -> Self {
        self.token_limiter = Some(TokenBucket::new(
            tokens_per_minute,
            tokens_per_minute / 60.0,
        ));
        self
    }
}
#[async_trait]
impl BaseChatModel for RateLimitedChatModel {
async fn _generate(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatResult> {
self.request_limiter.acquire().await;
self.inner._generate(messages, stop).await
}
fn llm_type(&self) -> &str {
self.inner.llm_type()
}
async fn _stream(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatStream> {
self.request_limiter.acquire().await;
self.inner._stream(messages, stop).await
}
fn bind_tools(
&self,
tools: &[ToolSchema],
tool_choice: Option<ToolChoice>,
) -> Result<Box<dyn BaseChatModel>> {
let bound_inner = self.inner.bind_tools(tools, tool_choice)?;
let rpm = self.request_limiter.capacity;
let mut wrapped = RateLimitedChatModel::new(bound_inner, rpm);
if let Some(ref tl) = self.token_limiter {
wrapped.token_limiter = Some(TokenBucket::new(tl.capacity, tl.refill_rate));
}
Ok(Box::new(wrapped))
}
fn profile(&self) -> ModelProfile {
self.inner.profile()
}
}
/// Convenience constructor: returns `model` throttled to `requests_per_minute`
/// as a boxed trait object.
pub fn with_rate_limit(
    model: Box<dyn BaseChatModel>,
    requests_per_minute: f64,
) -> Box<dyn BaseChatModel> {
    let limited = RateLimitedChatModel::new(model, requests_per_minute);
    Box::new(limited)
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::language_models::chat_model::ModelProfile;
    use cognis_core::messages::{AIMessage, Message};
    use cognis_core::outputs::{ChatGeneration, ChatResult};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal `BaseChatModel` that counts calls and returns a canned reply.
    struct MockChatModel {
        call_count: Arc<AtomicUsize>,
        supports_tools: bool,
    }

    impl MockChatModel {
        fn new() -> Self {
            Self {
                call_count: Arc::new(AtomicUsize::new(0)),
                supports_tools: false,
            }
        }

        fn with_tools() -> Self {
            Self {
                call_count: Arc::new(AtomicUsize::new(0)),
                supports_tools: true,
            }
        }

        // Kept for symmetry with `call_count`; tests currently read the
        // cloned Arc directly, so silence the dead-code warning.
        #[allow(dead_code)]
        fn count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl BaseChatModel for MockChatModel {
        async fn _generate(
            &self,
            _messages: &[Message],
            _stop: Option<&[String]>,
        ) -> Result<ChatResult> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(ChatResult {
                generations: vec![ChatGeneration {
                    text: "Hello!".to_string(),
                    message: Message::Ai(AIMessage::new("Hello!")),
                    generation_info: None,
                }],
                llm_output: None,
            })
        }

        fn llm_type(&self) -> &str {
            "mock"
        }

        fn bind_tools(
            &self,
            _tools: &[ToolSchema],
            _tool_choice: Option<ToolChoice>,
        ) -> Result<Box<dyn BaseChatModel>> {
            if self.supports_tools {
                Ok(Box::new(MockChatModel::with_tools()))
            } else {
                Err(cognis_core::error::CognisError::NotImplemented(
                    "mock does not support tool binding".into(),
                ))
            }
        }

        fn profile(&self) -> ModelProfile {
            ModelProfile {
                max_input_tokens: Some(128_000),
                tool_calling: Some(true),
                ..Default::default()
            }
        }
    }

    /// Draining exactly `capacity` tokens should leave less than one token.
    #[tokio::test]
    async fn test_token_bucket_allows_within_capacity() {
        let bucket = TokenBucket::new(5.0, 100.0);
        for _ in 0..5 {
            bucket.acquire().await;
        }
        assert!(bucket.available() < 1.0);
    }

    /// After draining, waiting long enough (1000 tokens/s * 50 ms = ~50
    /// tokens) must make at least one token available again.
    #[tokio::test]
    async fn test_token_bucket_refills_over_time() {
        let bucket = TokenBucket::new(10.0, 1000.0);
        for _ in 0..10 {
            bucket.acquire().await;
        }
        assert!(bucket.available() < 1.0);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(bucket.available() >= 1.0);
    }

    /// Refilling is clamped: a full bucket never exceeds its capacity
    /// (small epsilon for float arithmetic).
    #[tokio::test]
    async fn test_token_bucket_does_not_exceed_capacity() {
        let bucket = TokenBucket::new(5.0, 1000.0);
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(bucket.available() <= 5.0 + 0.01);
    }

    #[tokio::test]
    async fn test_rate_limited_model_delegates_generate() {
        let mock = MockChatModel::new();
        let count = mock.call_count.clone();
        let limited = RateLimitedChatModel::new(Box::new(mock), 600.0);
        let messages = vec![Message::Human(cognis_core::messages::HumanMessage::new(
            "Hi",
        ))];
        let result = limited._generate(&messages, None).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 1);
        assert_eq!(result.unwrap().generations[0].text, "Hello!");
    }

    /// With a token in the bucket the first acquire is immediate; with an
    /// empty bucket (refill 100/s, i.e. one token per 10 ms) the second
    /// acquire must block measurably.
    #[tokio::test]
    async fn test_rate_limited_model_delays_when_exhausted() {
        let bucket = TokenBucket::new(1.0, 0.1);
        let start = tokio::time::Instant::now();
        bucket.acquire().await;
        let first_elapsed = start.elapsed();
        assert!(first_elapsed < Duration::from_millis(50));

        let bucket = TokenBucket::new(1.0, 100.0);
        bucket.acquire().await;
        let start = tokio::time::Instant::now();
        bucket.acquire().await;
        let second_elapsed = start.elapsed();
        assert!(second_elapsed >= Duration::from_millis(5));
    }

    #[tokio::test]
    async fn test_rate_limited_model_llm_type() {
        let mock = MockChatModel::new();
        let limited = RateLimitedChatModel::new(Box::new(mock), 60.0);
        assert_eq!(limited.llm_type(), "mock");
    }

    #[tokio::test]
    async fn test_rate_limited_model_profile_delegates() {
        let mock = MockChatModel::new();
        let limited = RateLimitedChatModel::new(Box::new(mock), 60.0);
        let profile = limited.profile();
        assert_eq!(profile.max_input_tokens, Some(128_000));
        assert_eq!(profile.tool_calling, Some(true));
    }

    /// The model returned by `bind_tools` must still be wrapped and usable.
    #[tokio::test]
    async fn test_bind_tools_preserves_rate_limiting() {
        let mock = MockChatModel::with_tools();
        let limited = RateLimitedChatModel::new(Box::new(mock), 60.0);
        let tools: Vec<ToolSchema> = vec![];
        let result = limited.bind_tools(&tools, None);
        assert!(result.is_ok());
        let bound = result.unwrap();
        let messages = vec![Message::Human(cognis_core::messages::HumanMessage::new(
            "Hi",
        ))];
        let gen_result = bound._generate(&messages, None).await;
        assert!(gen_result.is_ok());
    }

    #[tokio::test]
    async fn test_with_rate_limit_convenience() {
        let mock = MockChatModel::new();
        let limited = with_rate_limit(Box::new(mock), 120.0);
        assert_eq!(limited.llm_type(), "mock");
        let messages = vec![Message::Human(cognis_core::messages::HumanMessage::new(
            "test",
        ))];
        let result = limited._generate(&messages, None).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_with_rate_constructor() {
        let mock = MockChatModel::new();
        let count = mock.call_count.clone();
        let limited = RateLimitedChatModel::with_rate(Box::new(mock), 10.0, 5);
        let messages = vec![Message::Human(cognis_core::messages::HumanMessage::new(
            "Hi",
        ))];
        let result = limited._generate(&messages, None).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_with_token_limit_builder() {
        let mock = MockChatModel::new();
        let limited = RateLimitedChatModel::new(Box::new(mock), 60.0).with_token_limit(100_000.0);
        assert!(limited.token_limiter.is_some());
        let messages = vec![Message::Human(cognis_core::messages::HumanMessage::new(
            "test",
        ))];
        let result = limited._generate(&messages, None).await;
        assert!(result.is_ok());
    }
}