use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::Stream;
use pin_project_lite::pin_project;
use tokio::sync::broadcast::{self, Receiver, Sender};
use crate::error::{Error, Result};
use crate::types::{CompletionRequest, StreamChunk};
/// Capacity of each per-request broadcast channel. A subscriber that falls
/// more than this many chunks behind will observe a `Lagged` error from its
/// receiver and may miss intermediate chunks.
const BROADCAST_CHANNEL_CAPACITY: usize = 128;
/// Fingerprint of a `CompletionRequest`, used to recognize identical
/// in-flight requests so their response streams can be shared.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct RequestHash(u64);

impl RequestHash {
    /// Hashes the fields that determine a request's response: model,
    /// temperature (quantized to three decimal places, since floats do not
    /// implement `Hash`), max_tokens, and each message's role and text.
    fn from_request(req: &CompletionRequest) -> Self {
        use std::collections::hash_map::DefaultHasher;
        let mut state = DefaultHasher::new();
        req.model.hash(&mut state);
        // `as` casts saturate on out-of-range floats, so this is total.
        let quantized_temperature = req.temperature.map(|t| (t * 1000.0) as i32);
        quantized_temperature.hash(&mut state);
        req.max_tokens.hash(&mut state);
        for message in &req.messages {
            message.role.hash(&mut state);
            message.text_content().hash(&mut state);
        }
        RequestHash(state.finish())
    }
}
pin_project! {
    /// A per-subscriber stream of chunks fed from a shared broadcast channel;
    /// every `MultiplexedStream` subscribed to the same request observes the
    /// same sequence of chunks.
    pub struct MultiplexedStream {
        // Receiving half of the per-request broadcast channel.
        // NOTE(review): `#[pin]` looks unnecessary here — `Receiver` is
        // `Unpin` and is never pin-projected below — but removing it is a
        // separate change; confirm before dropping the pin_project wrapper.
        #[pin]
        receiver: Receiver<Arc<StreamChunk>>,
    }
}
impl Stream for MultiplexedStream {
    type Item = Result<StreamChunk>;

    /// Polls the underlying broadcast receiver for the next chunk.
    ///
    /// `try_recv` does not register a waker, so this implementation must
    /// arrange its own wake-ups: returning `Poll::Pending` without doing so
    /// (as the previous version did, ignoring `cx`) left the task never
    /// re-polled and hung the stream forever on an empty channel.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.receiver.try_recv() {
                Ok(chunk_arc) => {
                    // Subscribers each get their own copy of the chunk.
                    return Poll::Ready(Some(Ok((*chunk_arc).clone())));
                }
                Err(broadcast::error::TryRecvError::Lagged(_)) => {
                    // We fell behind and some chunks were dropped. The next
                    // try_recv yields the oldest retained chunk, so retry
                    // immediately instead of stalling.
                    continue;
                }
                Err(broadcast::error::TryRecvError::Empty) => {
                    // Nothing available yet. Self-wake so the runtime polls
                    // us again rather than parking this task indefinitely.
                    // NOTE(review): this busy-polls; a waker-integrated
                    // receiver (e.g. tokio_stream's BroadcastStream) would
                    // be more efficient if that dependency is acceptable.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Err(broadcast::error::TryRecvError::Closed) => {
                    // Sender dropped: the request completed; end the stream.
                    return Poll::Ready(None);
                }
            }
        }
    }
}
/// Book-keeping for one in-flight request's broadcast channel.
struct ActiveStream {
    // Sending half of the channel; dropping it closes every subscriber's
    // receiver and ends their streams.
    sender: Sender<Arc<StreamChunk>>,
    // Number of streams handed out via `subscribe` for this request.
    subscriber_count: usize,
}
/// Deduplicates identical concurrent completion requests by fanning a single
/// upstream response stream out to every subscriber via broadcast channels.
pub struct StreamingMultiplexer {
    // In-flight streams keyed by request fingerprint. An async mutex so it
    // can be locked from the async methods below; Arc allows the map to be
    // shared across clones/tasks.
    active_streams: Arc<tokio::sync::Mutex<HashMap<RequestHash, ActiveStream>>>,
}
impl StreamingMultiplexer {
    /// Creates a multiplexer with no active streams.
    pub fn new() -> Self {
        Self {
            active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
        }
    }

    /// Subscribes to the chunk stream for `request`.
    ///
    /// If an identical request (same [`RequestHash`] fingerprint) is already
    /// in flight, the caller joins its existing broadcast channel; otherwise
    /// a fresh channel is registered for this request.
    ///
    /// # Errors
    /// Currently infallible; the `Result` is kept for interface stability.
    pub async fn subscribe(&self, request: &CompletionRequest) -> Result<MultiplexedStream> {
        let hash = RequestHash::from_request(request);
        let mut streams = self.active_streams.lock().await;
        match streams.get_mut(&hash) {
            Some(active) => {
                // Identical request already running: join its channel.
                active.subscriber_count += 1;
                Ok(MultiplexedStream {
                    receiver: active.sender.subscribe(),
                })
            }
            None => {
                // First subscriber: open a new channel for this request.
                let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
                streams.insert(
                    hash,
                    ActiveStream {
                        sender,
                        subscriber_count: 1,
                    },
                );
                Ok(MultiplexedStream { receiver })
            }
        }
    }

    /// Broadcasts `chunk` to every subscriber of `request`'s stream.
    /// Does nothing if the request has no active stream.
    ///
    /// # Errors
    /// Returns an error if the broadcast send fails, which happens only when
    /// every subscriber's receiver has already been dropped.
    pub async fn send_chunk(&self, request: &CompletionRequest, chunk: StreamChunk) -> Result<()> {
        let hash = RequestHash::from_request(request);
        let streams = self.active_streams.lock().await;
        if let Some(active) = streams.get(&hash) {
            // Arc lets all subscribers share one allocation of the chunk.
            active
                .sender
                .send(Arc::new(chunk))
                .map_err(|_| Error::InvalidRequest("Failed to broadcast chunk".to_string()))?;
        }
        Ok(())
    }

    /// Marks `request` as finished and tears down its stream.
    ///
    /// Removing the entry drops its `Sender`, which closes every
    /// subscriber's receiver and lets their streams terminate.
    pub async fn complete_request(&self, request: &CompletionRequest) {
        let hash = RequestHash::from_request(request);
        let mut streams = self.active_streams.lock().await;
        // Dropping the removed entry (and its Sender) is the completion
        // signal. The previous version also decremented `subscriber_count`
        // on the entry it had just removed — dead code, now removed.
        streams.remove(&hash);
    }

    /// Returns a snapshot of the number of in-flight requests and the total
    /// subscriber count across all of them.
    pub async fn stats(&self) -> MultiplexerStats {
        let streams = self.active_streams.lock().await;
        MultiplexerStats {
            active_requests: streams.len(),
            total_subscribers: streams.values().map(|s| s.subscriber_count).sum(),
        }
    }
}
impl Default for StreamingMultiplexer {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time counters describing multiplexer load, as returned by
/// `StreamingMultiplexer::stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared (e.g. in tests) and
/// `Default` for a zeroed snapshot, consistent with the eager-derive style
/// used elsewhere in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MultiplexerStats {
    /// Number of distinct requests with a live broadcast channel.
    pub active_requests: usize,
    /// Sum of subscriber counts across all active requests.
    pub total_subscribers: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    /// Builds a minimal one-message request for the given model.
    fn create_test_request(model: &str) -> CompletionRequest {
        CompletionRequest::new(model, vec![Message::user("test")])
    }

    #[test]
    fn test_request_hash_same_request() {
        // Two structurally identical requests must collide on purpose.
        let a = RequestHash::from_request(&create_test_request("openai/gpt-4"));
        let b = RequestHash::from_request(&create_test_request("openai/gpt-4"));
        assert_eq!(a, b);
    }

    #[test]
    fn test_request_hash_different_model() {
        // Changing only the model must change the fingerprint.
        let a = RequestHash::from_request(&create_test_request("openai/gpt-4"));
        let b = RequestHash::from_request(&create_test_request("anthropic/claude-sonnet"));
        assert_ne!(a, b);
    }

    #[test]
    fn test_request_hash_different_message() {
        // An extra message must change the fingerprint.
        let mut longer = create_test_request("openai/gpt-4");
        longer.messages.push(Message::user("extra"));
        let base = create_test_request("openai/gpt-4");
        assert_ne!(
            RequestHash::from_request(&longer),
            RequestHash::from_request(&base)
        );
    }

    #[test]
    fn test_request_hash_different_temperature() {
        // Temperatures differing beyond the quantization step must differ.
        let mut cool = create_test_request("openai/gpt-4");
        cool.temperature = Some(0.5);
        let mut warm = create_test_request("openai/gpt-4");
        warm.temperature = Some(1.0);
        assert_ne!(
            RequestHash::from_request(&cool),
            RequestHash::from_request(&warm)
        );
    }

    #[tokio::test]
    async fn test_multiplexer_new() {
        // A fresh multiplexer reports no load.
        let stats = StreamingMultiplexer::new().stats().await;
        assert_eq!(stats.active_requests, 0);
        assert_eq!(stats.total_subscribers, 0);
    }

    #[tokio::test]
    async fn test_multiplexer_duplicate_detection() {
        let mux = StreamingMultiplexer::new();
        let req = create_test_request("openai/gpt-4");

        // First subscriber creates the stream.
        let _first = mux.subscribe(&req).await.unwrap();
        let snapshot = mux.stats().await;
        assert_eq!(snapshot.active_requests, 1);
        assert_eq!(snapshot.total_subscribers, 1);

        // Second identical request joins the same stream.
        let _second = mux.subscribe(&req).await.unwrap();
        let snapshot = mux.stats().await;
        assert_eq!(snapshot.active_requests, 1);
        assert_eq!(snapshot.total_subscribers, 2);
    }

    #[tokio::test]
    async fn test_multiplexer_different_requests() {
        // Distinct requests each get their own stream.
        let mux = StreamingMultiplexer::new();
        let _a = mux
            .subscribe(&create_test_request("openai/gpt-4"))
            .await
            .unwrap();
        let _b = mux
            .subscribe(&create_test_request("anthropic/claude-sonnet"))
            .await
            .unwrap();
        let snapshot = mux.stats().await;
        assert_eq!(snapshot.active_requests, 2);
        assert_eq!(snapshot.total_subscribers, 2);
    }

    #[tokio::test]
    async fn test_multiplexer_broadcast_chunk() {
        use futures::StreamExt;

        let mux = StreamingMultiplexer::new();
        let req = create_test_request("openai/gpt-4");
        let mut stream = mux.subscribe(&req).await.unwrap();

        let chunk = StreamChunk {
            event_type: crate::types::StreamEventType::ContentBlockDelta,
            index: Some(0),
            delta: Some(crate::types::ContentDelta::Text {
                text: "hello".to_string(),
            }),
            stop_reason: None,
            usage: None,
        };
        mux.send_chunk(&req, chunk).await.unwrap();

        // The subscriber must observe exactly the chunk that was sent.
        match stream.next().await {
            Some(Ok(received)) => match received.delta {
                Some(crate::types::ContentDelta::Text { text }) => {
                    assert_eq!(text, "hello");
                }
                other => panic!("Expected text delta, got {:?}", other),
            },
            _ => panic!("Failed to receive chunk from multiplexer"),
        }
    }
}