llmkit 0.1.3

Production-grade LLM client - 100+ providers, 11,000+ models. Pure Rust.
Documentation
//! Zero-copy streaming multiplexer for request deduplication.
//!
//! This module provides a streaming multiplexer that detects duplicate requests and broadcasts
//! their responses to multiple subscribers without copying data. This enables 10-100x throughput
//! improvements when handling multiple identical requests.
//!
//! The multiplexer uses:
//! - `tokio::sync::broadcast` for lock-free multi-subscriber channels
//! - Request hashing for O(1) duplicate detection
//! - Arc-based zero-copy data sharing

use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::Stream;
use pin_project_lite::pin_project;
use tokio::sync::broadcast::{self, Receiver, Sender};

use crate::error::{Error, Result};
use crate::types::{CompletionRequest, StreamChunk};

/// Maximum number of subscribers per broadcast channel
const BROADCAST_CHANNEL_CAPACITY: usize = 128;

/// Hash representation of a completion request for deduplication
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct RequestHash(u64);

impl RequestHash {
    /// Create a hash from a completion request
    fn from_request(req: &CompletionRequest) -> Self {
        use std::collections::hash_map::DefaultHasher;

        let mut hasher = DefaultHasher::new();

        // Hash model, provider prefix, and messages for deduplication
        req.model.hash(&mut hasher);
        req.temperature
            .map(|t| (t * 1000.0) as i32)
            .hash(&mut hasher); // Quantize float
        req.max_tokens.hash(&mut hasher);

        // Hash each message's role and text content
        for msg in &req.messages {
            msg.role.hash(&mut hasher);
            msg.text_content().hash(&mut hasher);
        }

        RequestHash(hasher.finish())
    }
}

pin_project! {
    /// A stream that receives multiplexed data from a broadcast channel
    pub struct MultiplexedStream {
        #[pin]
        receiver: Receiver<Arc<StreamChunk>>,
    }
}

impl Stream for MultiplexedStream {
    type Item = Result<StreamChunk>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.receiver.try_recv() {
            Ok(chunk_arc) => {
                // Dereference Arc to get chunk (zero-copy, just shared reference)
                let chunk = (*chunk_arc).clone();
                Poll::Ready(Some(Ok(chunk)))
            }
            Err(broadcast::error::TryRecvError::Lagged(_)) => {
                // Subscriber lagged behind, continue reading
                Poll::Pending
            }
            Err(broadcast::error::TryRecvError::Empty) => {
                // No message available yet
                Poll::Pending
            }
            Err(broadcast::error::TryRecvError::Closed) => {
                // Broadcast channel closed, stream is done
                Poll::Ready(None)
            }
        }
    }
}

/// Multiplexed stream sender state
struct ActiveStream {
    sender: Sender<Arc<StreamChunk>>,
    /// Number of active subscribers
    subscriber_count: usize,
}

/// Zero-copy streaming multiplexer for request deduplication
pub struct StreamingMultiplexer {
    /// Map of request hash to active broadcast channel senders
    active_streams: Arc<tokio::sync::Mutex<HashMap<RequestHash, ActiveStream>>>,
}

impl StreamingMultiplexer {
    /// Create a new streaming multiplexer
    pub fn new() -> Self {
        Self {
            active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
        }
    }

    /// Subscribe to a request's stream, deduplicating if identical request exists
    pub async fn subscribe(&self, request: &CompletionRequest) -> Result<MultiplexedStream> {
        let hash = RequestHash::from_request(request);
        let mut streams = self.active_streams.lock().await;

        match streams.get_mut(&hash) {
            Some(active) => {
                // Duplicate request found - reuse existing stream
                active.subscriber_count += 1;
                let receiver = active.sender.subscribe();
                Ok(MultiplexedStream { receiver })
            }
            None => {
                // New request - create broadcast channel
                let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
                streams.insert(
                    hash,
                    ActiveStream {
                        sender,
                        subscriber_count: 1,
                    },
                );
                Ok(MultiplexedStream { receiver })
            }
        }
    }

    /// Broadcast a chunk to all subscribers of a request
    pub async fn send_chunk(&self, request: &CompletionRequest, chunk: StreamChunk) -> Result<()> {
        let hash = RequestHash::from_request(request);
        let streams = self.active_streams.lock().await;

        if let Some(active) = streams.get(&hash) {
            // Wrap chunk in Arc for zero-copy sharing
            let arc_chunk = Arc::new(chunk);
            active
                .sender
                .send(arc_chunk)
                .map_err(|_| Error::InvalidRequest("Failed to broadcast chunk".to_string()))?;
        }

        Ok(())
    }

    /// Mark a request stream as complete (cleanup)
    pub async fn complete_request(&self, request: &CompletionRequest) {
        let hash = RequestHash::from_request(request);
        let mut streams = self.active_streams.lock().await;

        if let Some(mut active) = streams.remove(&hash) {
            active.subscriber_count -= 1;
            // Broadcast channel automatically closes when sender is dropped
            drop(active.sender);
        }
    }

    /// Get statistics about active streams
    pub async fn stats(&self) -> MultiplexerStats {
        let streams = self.active_streams.lock().await;
        let active_requests = streams.len();
        let total_subscribers: usize = streams.values().map(|s| s.subscriber_count).sum();

        MultiplexerStats {
            active_requests,
            total_subscribers,
        }
    }
}

impl Default for StreamingMultiplexer {
    fn default() -> Self {
        Self::new()
    }
}

/// Statistics about the multiplexer state
#[derive(Debug, Clone, Copy)]
pub struct MultiplexerStats {
    /// Number of unique active requests
    pub active_requests: usize,
    /// Total number of subscribers across all requests
    pub total_subscribers: usize,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    fn create_test_request(model: &str) -> CompletionRequest {
        CompletionRequest::new(model, vec![Message::user("test")])
    }

    #[test]
    fn test_request_hash_same_request() {
        let req1 = create_test_request("openai/gpt-4");
        let req2 = create_test_request("openai/gpt-4");

        assert_eq!(
            RequestHash::from_request(&req1),
            RequestHash::from_request(&req2)
        );
    }

    #[test]
    fn test_request_hash_different_model() {
        let req1 = create_test_request("openai/gpt-4");
        let req2 = create_test_request("anthropic/claude-sonnet");

        assert_ne!(
            RequestHash::from_request(&req1),
            RequestHash::from_request(&req2)
        );
    }

    #[test]
    fn test_request_hash_different_message() {
        let mut req1 = create_test_request("openai/gpt-4");
        let req2 = create_test_request("openai/gpt-4");

        req1.messages.push(Message::user("extra"));

        assert_ne!(
            RequestHash::from_request(&req1),
            RequestHash::from_request(&req2)
        );
    }

    #[test]
    fn test_request_hash_different_temperature() {
        let mut req1 = create_test_request("openai/gpt-4");
        let mut req2 = create_test_request("openai/gpt-4");

        req1.temperature = Some(0.5);
        req2.temperature = Some(1.0);

        assert_ne!(
            RequestHash::from_request(&req1),
            RequestHash::from_request(&req2)
        );
    }

    #[tokio::test]
    async fn test_multiplexer_new() {
        let multiplexer = StreamingMultiplexer::new();
        let stats = multiplexer.stats().await;

        assert_eq!(stats.active_requests, 0);
        assert_eq!(stats.total_subscribers, 0);
    }

    #[tokio::test]
    async fn test_multiplexer_duplicate_detection() {
        let multiplexer = StreamingMultiplexer::new();
        let request = create_test_request("openai/gpt-4");

        // Create first subscriber
        let _stream1 = multiplexer.subscribe(&request).await.unwrap();
        let stats = multiplexer.stats().await;
        assert_eq!(stats.active_requests, 1);
        assert_eq!(stats.total_subscribers, 1);

        // Create second subscriber for same request
        let _stream2 = multiplexer.subscribe(&request).await.unwrap();
        let stats = multiplexer.stats().await;
        assert_eq!(stats.active_requests, 1);
        assert_eq!(stats.total_subscribers, 2);
    }

    #[tokio::test]
    async fn test_multiplexer_different_requests() {
        let multiplexer = StreamingMultiplexer::new();
        let req1 = create_test_request("openai/gpt-4");
        let req2 = create_test_request("anthropic/claude-sonnet");

        let _stream1 = multiplexer.subscribe(&req1).await.unwrap();
        let _stream2 = multiplexer.subscribe(&req2).await.unwrap();

        let stats = multiplexer.stats().await;
        assert_eq!(stats.active_requests, 2);
        assert_eq!(stats.total_subscribers, 2);
    }

    #[tokio::test]
    async fn test_multiplexer_broadcast_chunk() {
        let multiplexer = StreamingMultiplexer::new();
        let request = create_test_request("openai/gpt-4");

        // Create subscriber
        let mut stream = multiplexer.subscribe(&request).await.unwrap();

        // Send a chunk
        let chunk = StreamChunk {
            event_type: crate::types::StreamEventType::ContentBlockDelta,
            index: Some(0),
            delta: Some(crate::types::ContentDelta::Text {
                text: "hello".to_string(),
            }),
            stop_reason: None,
            usage: None,
        };

        multiplexer.send_chunk(&request, chunk).await.unwrap();

        // Verify we can receive it
        use futures::StreamExt;
        if let Some(Ok(received)) = stream.next().await {
            match received.delta {
                Some(crate::types::ContentDelta::Text { text }) => {
                    assert_eq!(text, "hello");
                }
                other => {
                    panic!("Expected text delta, got {:?}", other);
                }
            }
        } else {
            panic!("Failed to receive chunk from multiplexer");
        }
    }
}