// cc_switch/daemon/aggregate/stream.rs

1use crate::daemon::aggregate::state::AliasMap;
2use ccs_proxy::CaptureEvent;
3use serde::Serialize;
4use std::sync::Arc;
5use tokio::sync::broadcast;
6use tokio_stream::StreamExt;
7use tokio_stream::wrappers::BroadcastStream;
8
/// One proxy's capture-event source: the upstream identifier paired with a
/// broadcast receiver of that proxy's [`CaptureEvent`]s.
pub type ProxyEventReceiver = (String, broadcast::Receiver<CaptureEvent>);
10
11#[derive(Debug, Clone, Serialize)]
12pub struct TaggedCaptureEvent {
13    pub upstream: String,
14    pub aliases: Vec<String>,
15    #[serde(flatten)]
16    pub inner: CaptureEvent,
17}
18
19pub async fn event_merger(
20    proxy_events: Vec<ProxyEventReceiver>,
21    alias_map: Arc<AliasMap>,
22    merged_tx: broadcast::Sender<TaggedCaptureEvent>,
23) {
24    let streams: Vec<_> = proxy_events
25        .into_iter()
26        .map(|(upstream, rx)| {
27            let upstream = upstream.clone();
28            BroadcastStream::new(rx)
29                .filter_map(move |res| res.ok().map(|ev| (upstream.clone(), ev)))
30        })
31        .collect();
32
33    let mut merged = futures::stream::select_all(streams);
34
35    while let Some((upstream, event)) = merged.next().await {
36        let aliases = alias_map.aliases_for(&upstream);
37        let tagged = TaggedCaptureEvent {
38            upstream,
39            aliases,
40            inner: event,
41        };
42        let _ = merged_tx.send(tagged);
43    }
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use ccs_proxy::CaptureEvent;
    use std::time::Duration;
    use tokio::sync::broadcast;

    /// Pause inserted between spawning the merger and sending the first event.
    const SETTLE_DELAY: Duration = Duration::from_millis(10);
    /// Upper bound on how long a test waits for a merged event.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    /// Builds a `RequestStarted` event with `seq` 1 and the current UTC time.
    fn request_started(session_id: &str, model: Option<&str>) -> CaptureEvent {
        CaptureEvent::RequestStarted {
            session_id: session_id.to_string(),
            seq: 1,
            started_at: chrono::Utc::now(),
            model: model.map(|m| m.to_string()),
        }
    }

    /// Waits for the next merged event, panicking on timeout or receive error.
    async fn next_tagged(
        rx: &mut broadcast::Receiver<TaggedCaptureEvent>,
    ) -> TaggedCaptureEvent {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .unwrap()
            .unwrap()
    }

    /// An event sent on a known upstream comes out tagged with that upstream
    /// and the aliases registered for it.
    #[tokio::test]
    async fn merger_tags_events_with_upstream() {
        let (tx_a, _) = broadcast::channel::<CaptureEvent>(16);
        let (tx_b, _) = broadcast::channel::<CaptureEvent>(16);
        let (merged_tx, mut merged_rx) = broadcast::channel::<TaggedCaptureEvent>(64);

        let entries = vec![(
            "https://a.example.com".to_string(),
            vec!["alias_a".to_string()],
        )];
        let alias_map = Arc::new(AliasMap::from_entries(entries));

        let rx_a = tx_a.subscribe();
        let rx_b = tx_b.subscribe();
        let proxy_events = vec![
            ("https://a.example.com".to_string(), rx_a),
            ("https://b.example.com".to_string(), rx_b),
        ];

        let _merger = tokio::spawn(event_merger(proxy_events, alias_map, merged_tx));
        tokio::time::sleep(SETTLE_DELAY).await;

        tx_a.send(request_started("sess1", Some("claude-sonnet-4-6")))
            .unwrap();

        let tagged = next_tagged(&mut merged_rx).await;

        assert_eq!(tagged.upstream, "https://a.example.com");
        assert_eq!(tagged.aliases, vec!["alias_a"]);
    }

    /// An upstream with no alias-map entry still gets forwarded, with an
    /// empty alias list.
    #[tokio::test]
    async fn merger_handles_unknown_upstream_aliases() {
        let (tx_b, _) = broadcast::channel::<CaptureEvent>(16);
        let (merged_tx, mut merged_rx) = broadcast::channel::<TaggedCaptureEvent>(64);

        let alias_map = Arc::new(AliasMap::from_entries(vec![]));

        let rx_b = tx_b.subscribe();
        let proxy_events = vec![("https://b.example.com".to_string(), rx_b)];

        let _merger = tokio::spawn(event_merger(proxy_events, alias_map, merged_tx));
        tokio::time::sleep(SETTLE_DELAY).await;

        tx_b.send(request_started("sess2", None)).unwrap();

        let tagged = next_tagged(&mut merged_rx).await;

        assert_eq!(tagged.upstream, "https://b.example.com");
        assert!(tagged.aliases.is_empty());
    }
}