1use crate::StreamingModelProvider;
2use crate::provider::LlmResponseStream;
3use std::sync::atomic::{AtomicUsize, Ordering};
4
5use super::Context;
6
7#[doc = include_str!("docs/alloyed.md")]
8pub struct AlloyedModelProvider {
9 providers: Vec<Box<dyn StreamingModelProvider>>,
10 current_provider_index: AtomicUsize,
11}
12
13impl AlloyedModelProvider {
14 pub fn new(providers: Vec<Box<dyn StreamingModelProvider>>) -> Self {
15 Self { providers, current_provider_index: AtomicUsize::new(0) }
16 }
17
18 fn get_current_provider(&self) -> Option<&dyn StreamingModelProvider> {
19 if self.providers.is_empty() {
20 return None;
21 }
22 let index = self.current_provider_index.load(Ordering::Relaxed) % self.providers.len();
23 Some(self.providers[index].as_ref())
24 }
25
26 fn get_next_provider(&self) -> Option<&dyn StreamingModelProvider> {
27 if self.providers.is_empty() {
28 return None;
29 }
30 let index = self.current_provider_index.fetch_add(1, Ordering::Relaxed) % self.providers.len();
31 Some(self.providers[index].as_ref())
32 }
33}
34
35impl StreamingModelProvider for AlloyedModelProvider {
36 fn stream_response(&self, context: &Context) -> LlmResponseStream {
37 match self.get_next_provider() {
38 Some(provider) => provider.stream_response(context),
39 None => Box::pin(tokio_stream::empty()),
40 }
41 }
42
43 fn model(&self) -> Option<crate::LlmModel> {
44 self.get_current_provider().and_then(super::provider::StreamingModelProvider::model)
45 }
46
47 fn display_name(&self) -> String {
48 match self.get_current_provider() {
49 Some(provider) => provider.display_name(),
50 None => "Alloyed (no providers)".to_string(),
51 }
52 }
53
54 fn context_window(&self) -> Option<u32> {
55 if self.providers.is_empty() {
56 return None;
57 }
58
59 let mut min_context: Option<u32> = None;
60 for provider in &self.providers {
61 let context = provider.context_window()?;
62 min_context = Some(min_context.map_or(context, |current| current.min(context)));
63 }
64
65 min_context
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LlmResponse;
    use crate::testing::FakeLlmProvider;

    /// Display name these tests expect from `FakeLlmProvider`.
    const FAKE_NAME: &str = "Fake LLM";

    /// Test double whose only configurable property is its context window.
    struct FixedContextProvider {
        context_window: Option<u32>,
    }

    impl StreamingModelProvider for FixedContextProvider {
        fn stream_response(&self, _context: &Context) -> LlmResponseStream {
            Box::pin(tokio_stream::iter(vec![Ok(LlmResponse::done())]))
        }

        fn display_name(&self) -> String {
            "Fixed Context".to_string()
        }

        fn context_window(&self) -> Option<u32> {
            self.context_window
        }
    }

    /// Boxes a fresh `FakeLlmProvider` built from a single `LlmResponse::done()`.
    fn fake() -> Box<dyn StreamingModelProvider> {
        Box::new(FakeLlmProvider::new(vec![vec![LlmResponse::done()]]))
    }

    /// Builds an alloyed provider wrapping `count` fake providers.
    fn alloy_of_fakes(count: usize) -> AlloyedModelProvider {
        AlloyedModelProvider::new((0..count).map(|_| fake()).collect())
    }

    /// Boxes a `FixedContextProvider` reporting the given context window.
    fn fixed(context_window: Option<u32>) -> Box<dyn StreamingModelProvider> {
        Box::new(FixedContextProvider { context_window })
    }

    #[test]
    fn test_alloyed_provider_display_name_empty() {
        let alloy = AlloyedModelProvider::new(Vec::new());
        assert_eq!(alloy.display_name(), "Alloyed (no providers)");
    }

    #[test]
    fn test_alloyed_provider_display_name_single() {
        let alloy = alloy_of_fakes(1);

        assert_eq!(alloy.display_name(), FAKE_NAME);
    }

    #[test]
    fn test_alloyed_provider_display_name_multiple() {
        let alloy = alloy_of_fakes(2);

        // Asking twice in a row must yield the same name both times.
        for _ in 0..2 {
            assert_eq!(alloy.display_name(), FAKE_NAME);
        }
    }

    #[test]
    fn test_alloyed_provider_cycling() {
        let alloy = alloy_of_fakes(2);
        let context = Context::new(vec![], vec![]);

        // Interleave a stream request with a name lookup three times; the
        // streams are retained so they stay alive until the test ends.
        let mut streams = Vec::new();
        let mut names = Vec::new();
        for _ in 0..3 {
            streams.push(alloy.stream_response(&context));
            names.push(alloy.display_name());
        }

        for name in names {
            assert_eq!(name, FAKE_NAME);
        }
    }

    #[test]
    fn test_display_name_doesnt_advance_counter() {
        let alloy = alloy_of_fakes(2);

        let names: Vec<String> = (0..3).map(|_| alloy.display_name()).collect();

        for name in names {
            assert_eq!(name, FAKE_NAME);
        }
    }

    #[test]
    fn test_context_window_unknown_if_any_provider_unknown() {
        // The fake provider stands in for one with an unknown context window.
        let alloy = AlloyedModelProvider::new(vec![fixed(Some(200_000)), fake()]);
        assert_eq!(alloy.context_window(), None);
    }

    #[test]
    fn test_context_window_uses_min_of_known_providers() {
        let alloy = AlloyedModelProvider::new(vec![fixed(Some(200_000)), fixed(Some(128_000))]);
        assert_eq!(alloy.context_window(), Some(128_000));
    }
}