ai_lib/provider/strategies/
round_robin.rs1use std::sync::atomic::{AtomicUsize, Ordering};
2
3use async_trait::async_trait;
4use futures::stream::Stream;
5
6use crate::{
7 api::{ChatCompletionChunk, ChatProvider, ModelInfo},
8 types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse},
9};
10
11pub struct RoundRobinProvider {
12 name: String,
13 providers: Vec<Box<dyn ChatProvider>>,
14 cursor: AtomicUsize,
15}
16
17impl RoundRobinProvider {
18 pub fn new(providers: Vec<Box<dyn ChatProvider>>) -> Result<Self, AiLibError> {
19 if providers.is_empty() {
20 return Err(AiLibError::ConfigurationError(
21 "round_robin strategy requires at least one provider".to_string(),
22 ));
23 }
24
25 let composed_name = providers
26 .iter()
27 .map(|p| p.name().to_string())
28 .collect::<Vec<_>>()
29 .join(",");
30
31 Ok(Self {
32 name: format!("round_robin[{composed_name}]"),
33 providers,
34 cursor: AtomicUsize::new(0),
35 })
36 }
37
38 fn select(&self) -> &dyn ChatProvider {
39 let idx = self.cursor.fetch_add(1, Ordering::Relaxed) % self.providers.len();
40 self.providers[idx].as_ref()
41 }
42}
43
44#[async_trait]
45impl ChatProvider for RoundRobinProvider {
46 fn name(&self) -> &str {
47 &self.name
48 }
49
50 async fn chat(
51 &self,
52 request: ChatCompletionRequest,
53 ) -> Result<ChatCompletionResponse, AiLibError> {
54 self.select().chat(request).await
55 }
56
57 async fn stream(
58 &self,
59 request: ChatCompletionRequest,
60 ) -> Result<
61 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
62 AiLibError,
63 > {
64 self.select().stream(request).await
65 }
66
67 async fn batch(
68 &self,
69 requests: Vec<ChatCompletionRequest>,
70 concurrency_limit: Option<usize>,
71 ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
72 self.select().batch(requests, concurrency_limit).await
73 }
74
75 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
76 self.select().list_models().await
77 }
78
79 async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
80 for provider in &self.providers {
81 match provider.get_model_info(model_id).await {
82 Ok(info) => return Ok(info),
83 Err(err) => {
84 if matches!(err, AiLibError::ModelNotFound(_)) {
85 continue;
86 }
87 return Err(err);
88 }
89 }
90 }
91
92 Err(AiLibError::ModelNotFound(format!(
93 "model {model_id} not available in round robin chain"
94 )))
95 }
96}