use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
use bamboo_agent_core::{tools::ToolSchema, Message};
use bamboo_infrastructure::provider::{LLMProvider, LLMRequestOptions, Result};
use bamboo_infrastructure::LLMStream;
/// An [`LLMProvider`] that delegates every call to a provider stored in a
/// shared, lock-protected slot.
///
/// The slot is read at the start of each call, so replacing the
/// `Arc<dyn LLMProvider>` inside the [`RwLock`] changes which provider later
/// requests are routed to, without rebuilding this wrapper.
///
/// Cloning is cheap (it only bumps the outer `Arc`'s refcount), and every
/// clone observes the same slot.
#[derive(Clone)]
pub struct ReloadableProvider {
    /// Shared slot holding the provider that requests are currently routed to.
    inner: Arc<RwLock<Arc<dyn LLMProvider>>>,
}
impl ReloadableProvider {
    /// Builds a wrapper around an existing shared provider slot.
    ///
    /// The slot is stored as-is, not copied. Any other holder of the same
    /// `Arc` presumably swaps providers by writing through the lock (NOTE:
    /// writer not visible in this file). Such writes are seen by this wrapper
    /// on its next call.
    pub fn new(inner: Arc<RwLock<Arc<dyn LLMProvider>>>) -> Self {
        Self { inner }
    }

    /// Snapshots the provider currently installed in the slot.
    ///
    /// Takes the read lock only long enough to bump the provider's refcount.
    /// The guard is dropped when this function returns, so callers can await
    /// on the returned handle without holding the lock.
    async fn current(&self) -> Arc<dyn LLMProvider> {
        let guard = self.inner.read().await;
        Arc::clone(&*guard)
    }
}
/// Pure delegation: each method fetches the active provider once via
/// `current()` and forwards its arguments unchanged.
///
/// `current()` returns an owned `Arc`, so the slot's lock is not held while
/// the delegated request runs. A request that is already in flight keeps using
/// the provider it started with, even if the slot is swapped mid-call.
#[async_trait]
impl LLMProvider for ReloadableProvider {
    /// Forwards to `chat_stream` on the provider installed when the call begins.
    async fn chat_stream(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        max_output_tokens: Option<u32>,
        model: &str,
    ) -> Result<LLMStream> {
        self.current()
            .await
            .chat_stream(messages, tools, max_output_tokens, model)
            .await
    }

    /// Forwards to `chat_stream_with_options` on the provider installed when
    /// the call begins, passing `options` through untouched.
    async fn chat_stream_with_options(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        max_output_tokens: Option<u32>,
        model: &str,
        options: Option<&LLMRequestOptions>,
    ) -> Result<LLMStream> {
        self.current()
            .await
            .chat_stream_with_options(messages, tools, max_output_tokens, model, options)
            .await
    }

    /// Forwards to `list_models` on the provider installed when the call begins.
    async fn list_models(&self) -> Result<Vec<String>> {
        self.current().await.list_models().await
    }
}