// bamboo_server/reloadable_provider.rs

use std::sync::Arc;

use async_trait::async_trait;
use tokio::sync::RwLock;

use bamboo_agent_core::{tools::ToolSchema, Message};
use bamboo_infrastructure::provider::{LLMProvider, LLMRequestOptions, Result};
use bamboo_infrastructure::LLMStream;
8
9/// An `LLMProvider` wrapper that always delegates to the latest provider stored in a shared lock.
10///
11/// This prevents stale provider snapshots after runtime config changes (provider/model/proxy),
12/// while keeping the call sites ergonomic (`Arc<dyn LLMProvider>`).
13pub struct ReloadableProvider {
14    inner: Arc<RwLock<Arc<dyn LLMProvider>>>,
15}
16
17impl ReloadableProvider {
18    pub fn new(inner: Arc<RwLock<Arc<dyn LLMProvider>>>) -> Self {
19        Self { inner }
20    }
21
22    async fn current(&self) -> Arc<dyn LLMProvider> {
23        self.inner.read().await.clone()
24    }
25}
26
27#[async_trait]
28impl LLMProvider for ReloadableProvider {
29    async fn chat_stream(
30        &self,
31        messages: &[Message],
32        tools: &[ToolSchema],
33        max_output_tokens: Option<u32>,
34        model: &str,
35    ) -> Result<LLMStream> {
36        let provider = self.current().await;
37        provider
38            .chat_stream(messages, tools, max_output_tokens, model)
39            .await
40    }
41
42    async fn chat_stream_with_options(
43        &self,
44        messages: &[Message],
45        tools: &[ToolSchema],
46        max_output_tokens: Option<u32>,
47        model: &str,
48        options: Option<&LLMRequestOptions>,
49    ) -> Result<LLMStream> {
50        let provider = self.current().await;
51        provider
52            .chat_stream_with_options(messages, tools, max_output_tokens, model, options)
53            .await
54    }
55
56    async fn list_models(&self) -> Result<Vec<String>> {
57        let provider = self.current().await;
58        provider.list_models().await
59    }
60}