Skip to main content

llm/
alloyed.rs

1use crate::StreamingModelProvider;
2use crate::provider::LlmResponseStream;
3use std::sync::atomic::{AtomicUsize, Ordering};
4
5use super::Context;
6
7#[doc = include_str!("docs/alloyed.md")]
8pub struct AlloyedModelProvider {
9    providers: Vec<Box<dyn StreamingModelProvider>>,
10    current_provider_index: AtomicUsize,
11}
12
13impl AlloyedModelProvider {
14    pub fn new(providers: Vec<Box<dyn StreamingModelProvider>>) -> Self {
15        Self { providers, current_provider_index: AtomicUsize::new(0) }
16    }
17
18    fn get_current_provider(&self) -> Option<&dyn StreamingModelProvider> {
19        if self.providers.is_empty() {
20            return None;
21        }
22        let index = self.current_provider_index.load(Ordering::Relaxed) % self.providers.len();
23        Some(self.providers[index].as_ref())
24    }
25
26    fn get_next_provider(&self) -> Option<&dyn StreamingModelProvider> {
27        if self.providers.is_empty() {
28            return None;
29        }
30        let index = self.current_provider_index.fetch_add(1, Ordering::Relaxed) % self.providers.len();
31        Some(self.providers[index].as_ref())
32    }
33}
34
35impl StreamingModelProvider for AlloyedModelProvider {
36    fn stream_response(&self, context: &Context) -> LlmResponseStream {
37        match self.get_next_provider() {
38            Some(provider) => provider.stream_response(context),
39            None => Box::pin(tokio_stream::empty()),
40        }
41    }
42
43    fn model(&self) -> Option<crate::LlmModel> {
44        self.get_current_provider().and_then(super::provider::StreamingModelProvider::model)
45    }
46
47    fn display_name(&self) -> String {
48        match self.get_current_provider() {
49            Some(provider) => provider.display_name(),
50            None => "Alloyed (no providers)".to_string(),
51        }
52    }
53
54    fn context_window(&self) -> Option<u32> {
55        if self.providers.is_empty() {
56            return None;
57        }
58
59        let mut min_context: Option<u32> = None;
60        for provider in &self.providers {
61            let context = provider.context_window()?;
62            min_context = Some(min_context.map_or(context, |current| current.min(context)));
63        }
64
65        min_context
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LlmResponse;
    use crate::testing::FakeLlmProvider;

    /// Test double whose only configurable behavior is the context window it reports.
    struct FixedContextProvider {
        context_window: Option<u32>,
    }

    impl StreamingModelProvider for FixedContextProvider {
        fn stream_response(&self, _context: &Context) -> LlmResponseStream {
            Box::pin(tokio_stream::iter(vec![Ok(LlmResponse::done())]))
        }

        fn display_name(&self) -> String {
            "Fixed Context".to_string()
        }

        fn context_window(&self) -> Option<u32> {
            self.context_window
        }
    }

    /// Test double with a configurable display name. Each member of an alloy built from
    /// these can be told apart, which is not possible with `FakeLlmProvider`: every
    /// instance of it reports "Fake LLM".
    struct NamedProvider {
        name: &'static str,
    }

    impl StreamingModelProvider for NamedProvider {
        fn stream_response(&self, _context: &Context) -> LlmResponseStream {
            Box::pin(tokio_stream::iter(vec![Ok(LlmResponse::done())]))
        }

        fn display_name(&self) -> String {
            self.name.to_string()
        }

        fn context_window(&self) -> Option<u32> {
            None
        }
    }

    #[test]
    fn test_alloyed_provider_display_name_empty() {
        let provider = AlloyedModelProvider::new(vec![]);
        assert_eq!(provider.display_name(), "Alloyed (no providers)");
    }

    #[test]
    fn test_alloyed_provider_display_name_single() {
        let fake_provider = FakeLlmProvider::new(vec![vec![LlmResponse::done()]]);
        let provider = AlloyedModelProvider::new(vec![Box::new(fake_provider)]);

        // Should return the individual provider's display name
        assert_eq!(provider.display_name(), "Fake LLM");
    }

    #[test]
    fn test_alloyed_provider_display_name_multiple() {
        let fake_provider1 = FakeLlmProvider::new(vec![vec![LlmResponse::done()]]);
        let fake_provider2 = FakeLlmProvider::new(vec![vec![LlmResponse::done()]]);
        let provider = AlloyedModelProvider::new(vec![Box::new(fake_provider1), Box::new(fake_provider2)]);

        // display_name is read-only: without a stream_response call, both reads report the
        // first provider (only stream_response advances the rotation).
        assert_eq!(provider.display_name(), "Fake LLM");
        assert_eq!(provider.display_name(), "Fake LLM");
    }

    #[test]
    fn test_alloyed_provider_cycling() {
        let provider = AlloyedModelProvider::new(vec![
            Box::new(NamedProvider { name: "first" }),
            Box::new(NamedProvider { name: "second" }),
        ]);
        let context = Context::new(vec![], vec![]);

        // Before any request, the cursor is on provider 0.
        assert_eq!(provider.display_name(), "first");

        // Each stream_response is served by the provider under the cursor and then advances
        // it, so display_name always reports the provider that will serve the next request.
        let _stream1 = provider.stream_response(&context); // served by provider 0
        assert_eq!(provider.display_name(), "second");

        let _stream2 = provider.stream_response(&context); // served by provider 1, wraps
        assert_eq!(provider.display_name(), "first");

        let _stream3 = provider.stream_response(&context); // served by provider 0
        assert_eq!(provider.display_name(), "second");
    }

    #[test]
    fn test_display_name_doesnt_advance_counter() {
        let provider = AlloyedModelProvider::new(vec![
            Box::new(NamedProvider { name: "first" }),
            Box::new(NamedProvider { name: "second" }),
        ]);

        // Repeated display_name calls must keep pointing at the same provider. Distinct
        // names make an accidental advance observable.
        assert_eq!(provider.display_name(), "first");
        assert_eq!(provider.display_name(), "first");
        assert_eq!(provider.display_name(), "first");
    }

    #[test]
    fn test_empty_alloy_reports_no_model_or_context_window() {
        let provider = AlloyedModelProvider::new(vec![]);
        assert!(provider.model().is_none());
        assert_eq!(provider.context_window(), None);
    }

    #[test]
    fn test_context_window_unknown_if_any_provider_unknown() {
        let known = FixedContextProvider { context_window: Some(200_000) };
        let unknown = FakeLlmProvider::new(vec![vec![LlmResponse::done()]]);
        let provider = AlloyedModelProvider::new(vec![Box::new(known), Box::new(unknown)]);
        assert_eq!(provider.context_window(), None);
    }

    #[test]
    fn test_context_window_uses_min_of_known_providers() {
        let p1 = FixedContextProvider { context_window: Some(200_000) };
        let p2 = FixedContextProvider { context_window: Some(128_000) };
        let provider = AlloyedModelProvider::new(vec![Box::new(p1), Box::new(p2)]);
        assert_eq!(provider.context_window(), Some(128_000));
    }
}