Skip to main content

synaptic_models/
rate_limit.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, SynapticError};
5use tokio::sync::Semaphore;
6
7pub struct RateLimitedChatModel {
8    inner: Arc<dyn ChatModel>,
9    semaphore: Arc<Semaphore>,
10}
11
12impl RateLimitedChatModel {
13    pub fn new(inner: Arc<dyn ChatModel>, max_concurrent: usize) -> Self {
14        Self {
15            inner,
16            semaphore: Arc::new(Semaphore::new(max_concurrent)),
17        }
18    }
19}
20
21#[async_trait]
22impl ChatModel for RateLimitedChatModel {
23    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
24        let _permit = self
25            .semaphore
26            .acquire()
27            .await
28            .map_err(|e| SynapticError::Model(format!("semaphore error: {e}")))?;
29        self.inner.chat(request).await
30    }
31
32    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
33        let inner = self.inner.clone();
34        let semaphore = self.semaphore.clone();
35
36        Box::pin(async_stream::stream! {
37            let _permit = match semaphore.acquire_owned().await {
38                Ok(p) => p,
39                Err(e) => {
40                    yield Err(SynapticError::Model(format!("semaphore error: {e}")));
41                    return;
42                }
43            };
44
45            use futures::StreamExt;
46            let mut stream = inner.stream_chat(request);
47            while let Some(result) = stream.next().await {
48                yield result;
49            }
50        })
51    }
52}