ai_lib/provider/strategies/
round_robin.rs

use std::sync::atomic::{AtomicUsize, Ordering};

use async_trait::async_trait;
use futures::stream::Stream;

use crate::{
    api::{ChatCompletionChunk, ChatProvider, ModelInfo},
    types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse},
};

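/// Composite [`ChatProvider`] that distributes requests across a fixed set of
/// inner providers in round-robin order. The composed provider reports a name
/// derived from its members, e.g. `round_robin[a,b]`.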
pub struct RoundRobinProvider {
    name: String,
    providers: Vec<Box<dyn ChatProvider>>,
    cursor: AtomicUsize,
}

impl RoundRobinProvider {
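    /// Builds the composite from a non-empty provider list; an empty list is
    /// rejected as a configuration error.
    ///
    /// A minimal usage sketch; `openai` and `anthropic` are stand-ins for any
    /// concrete `ChatProvider` implementations and are hypothetical here:
    ///
    /// ```ignore
    /// let provider = RoundRobinProvider::new(vec![
    ///     Box::new(openai) as Box<dyn ChatProvider>,
    ///     Box::new(anthropic),
    /// ])?;
    /// assert!(provider.name().starts_with("round_robin["));
    /// ```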
    pub fn new(providers: Vec<Box<dyn ChatProvider>>) -> Result<Self, AiLibError> {
        if providers.is_empty() {
            return Err(AiLibError::ConfigurationError(
                "round_robin strategy requires at least one provider".to_string(),
            ));
        }

        let composed_name = providers
            .iter()
            .map(|p| p.name().to_string())
            .collect::<Vec<_>>()
            .join(",");

        Ok(Self {
            name: format!("round_robin[{composed_name}]"),
            providers,
            cursor: AtomicUsize::new(0),
        })
    }

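    /// Picks the next provider in rotation. `Relaxed` ordering suffices here:
    /// the counter only spreads calls across providers and synchronizes no
    /// other memory, and `fetch_add` wrapping on overflow is harmless for the
    /// modulo below.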
    fn select(&self) -> &dyn ChatProvider {
        let idx = self.cursor.fetch_add(1, Ordering::Relaxed) % self.providers.len();
        self.providers[idx].as_ref()
    }
}

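// Chat, streaming, batch, and model listing each delegate to whichever provider
// `select` returns next; only `get_model_info` fans out across the whole chain.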
#[async_trait]
impl ChatProvider for RoundRobinProvider {
    fn name(&self) -> &str {
        &self.name
    }

    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.select().chat(request).await
    }

    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        self.select().stream(request).await
    }

    async fn batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        self.select().batch(requests, concurrency_limit).await
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        self.select().list_models().await
    }

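    // Model lookup cannot be served by a single rotation step: scan every
    // wrapped provider, skip `ModelNotFound`, and surface any other error.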
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        for provider in &self.providers {
            match provider.get_model_info(model_id).await {
                Ok(info) => return Ok(info),
                Err(err) => {
                    if matches!(err, AiLibError::ModelNotFound(_)) {
                        continue;
                    }
                    return Err(err);
                }
            }
        }

        Err(AiLibError::ModelNotFound(format!(
            "model {model_id} not available in round robin chain"
        )))
    }
}