use crate::AgentInfo;
use crate::agent_variant::AgentVariant;
use crate::error::RandAgentError;
use backon::{ExponentialBuilder, Retryable};
use rand::Rng;
use rig::completion::{CompletionError, Message, Prompt, PromptError};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
/// Optional hook called with an agent's id once that agent has been marked invalid.
///
/// The closure sits behind an `Arc`, so cloning a holder of this value shares the same closure.
pub type OnAgentInvalidCallback =
    Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
/// A pool of agents that answers each prompt with a randomly chosen *valid* agent.
///
/// Cloning is cheap: every clone shares the same agent list (and therefore the same
/// failure counters) through `Arc<Mutex<_>>`.
#[derive(Clone)]
pub struct RandAgent {
    /// Every registered agent, including ones that have exhausted their failure budget.
    agents: Arc<Mutex<Vec<AgentState>>>,
    /// Optional hook invoked with an agent's id when a failure makes that agent invalid.
    on_agent_invalid: OnAgentInvalidCallback,
}
/// One member of a [`RandAgent`] pool together with its bookkeeping.
#[derive(Clone)]
pub struct AgentState {
    /// Caller-supplied identifier; the same value is stored in `info.id`.
    pub id: i32,
    /// Shared handle to the underlying agent, cheap to clone out of the pool.
    pub agent: Arc<AgentVariant>,
    /// Provider/model names plus the current and maximum failure counts.
    pub info: AgentInfo,
}
impl Prompt for RandAgent {
#[allow(refining_impl_trait)]
async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
let agent_index = self
.get_random_valid_agent_index()
.await
.ok_or(CompletionError::ProviderError("没有有效agent".to_string()))?;
let mut agents = self.agents.lock().await;
let agent_state = &mut agents[agent_index];
tracing::info!(
"Using provider: {}, model: {},id: {}",
agent_state.info.provider,
agent_state.info.model,
agent_state.info.id
);
match agent_state.agent.prompt(prompt).await {
Ok(content) => {
agent_state.record_success();
Ok(content)
}
Err(e) => {
agent_state.record_failure();
if !agent_state.is_valid()
&& let Some(cb) = &self.on_agent_invalid
{
cb(agent_state.id);
}
Err(e)
}
}
}
}
impl AgentState {
    /// Builds the bookkeeping for a newly registered agent.
    ///
    /// The agent is moved behind an `Arc`, and its failure counter starts at zero with the
    /// given `max_failures` limit.
    fn new(
        agent: AgentVariant,
        id: i32,
        provider: String,
        model: String,
        max_failures: u32,
    ) -> Self {
        let shared_agent = Arc::new(agent);
        let info = AgentInfo {
            id,
            provider,
            model,
            failure_count: 0,
            max_failures,
        };
        Self {
            id,
            agent: shared_agent,
            info,
        }
    }

    /// Returns `true` while the agent has failed fewer times than its limit allows.
    ///
    /// With a limit of 0 the agent is never valid.
    fn is_valid(&self) -> bool {
        self.info.max_failures > self.info.failure_count
    }

    /// Counts one more failure against this agent.
    fn record_failure(&mut self) {
        let count = &mut self.info.failure_count;
        *count += 1;
    }

    /// Clears the failure counter; failures are therefore counted since the last success.
    fn record_success(&mut self) {
        self.info.failure_count = 0;
    }
}
impl RandAgent {
    /// Failure limit used by [`RandAgent::new`] and [`RandAgent::add_agent`].
    const DEFAULT_MAX_FAILURES: u32 = 3;

    /// Creates a pool from `(agent, id, provider, model)` tuples.
    ///
    /// Each agent gets a failure limit of 3, and no invalidation callback is set.
    pub fn new(agents: Vec<(AgentVariant, i32, String, String)>) -> Self {
        Self::with_max_failures_and_callback(agents, Self::DEFAULT_MAX_FAILURES, None)
    }

    /// Creates a pool with a per-agent failure limit and an optional invalidation hook.
    ///
    /// Each agent may fail `max_failures` times in a row (a success resets the count) before
    /// it stops being selected. When that happens, `on_agent_invalid` (if set) is called with
    /// the agent's id.
    pub fn with_max_failures_and_callback(
        agents: Vec<(AgentVariant, i32, String, String)>,
        max_failures: u32,
        on_agent_invalid: OnAgentInvalidCallback,
    ) -> Self {
        let agent_states: Vec<AgentState> = agents
            .into_iter()
            .map(|(agent, id, provider, model)| {
                AgentState::new(agent, id, provider, model, max_failures)
            })
            .collect();
        Self {
            agents: Arc::new(Mutex::new(agent_states)),
            on_agent_invalid,
        }
    }

    /// Creates a pool with a per-agent failure limit and no invalidation callback.
    pub fn with_max_failures(
        agents: Vec<(AgentVariant, i32, String, String)>,
        max_failures: u32,
    ) -> Self {
        Self::with_max_failures_and_callback(agents, max_failures, None)
    }

    /// Installs (or replaces) the callback invoked when an agent becomes invalid.
    ///
    /// The callback is stored per handle. Clones taken before this call keep their previous
    /// callback, even though they share the agent list.
    pub fn set_on_agent_invalid<F>(&mut self, callback: F)
    where
        F: Fn(i32) + Send + Sync + 'static,
    {
        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
    }

    /// Adds an agent with the default failure limit of 3.
    ///
    /// The default is used regardless of the limit the pool was constructed with.
    pub async fn add_agent(&self, agent: AgentVariant, id: i32, provider: String, model: String) {
        self.add_agent_with_max_failures(agent, id, provider, model, Self::DEFAULT_MAX_FAILURES)
            .await;
    }

    /// Adds an agent with an explicit failure limit. The new agent is visible to all clones.
    pub async fn add_agent_with_max_failures(
        &self,
        agent: AgentVariant,
        id: i32,
        provider: String,
        model: String,
        max_failures: u32,
    ) {
        let mut agents = self.agents.lock().await;
        agents.push(AgentState::new(agent, id, provider, model, max_failures));
    }

    /// Number of agents that are still valid (eligible for selection).
    ///
    /// See [`RandAgent::total_len`] for the count including invalid agents.
    pub async fn len(&self) -> usize {
        let agents = self.agents.lock().await;
        agents.iter().filter(|state| state.is_valid()).count()
    }

    /// Returns the index of a uniformly chosen valid agent, or `None` if there is none.
    pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
        let agents = self.agents.lock().await;
        Self::pick_random_valid_index(&agents)
    }

    /// Returns a snapshot (clone) of a uniformly chosen valid agent, or `None` if there is
    /// none.
    pub async fn get_random_valid_agent_state(&self) -> Option<AgentState> {
        let agents = self.agents.lock().await;
        Self::pick_random_valid_index(&agents).map(|index| agents[index].clone())
    }

    /// Total number of registered agents, valid or not.
    pub async fn total_len(&self) -> usize {
        let agents = self.agents.lock().await;
        agents.len()
    }

    /// Returns `true` when no agent is currently valid.
    ///
    /// This is `true` even if invalid agents remain registered.
    pub async fn is_empty(&self) -> bool {
        self.len().await == 0
    }

    /// Returns a snapshot of every agent's metadata and failure counters, and logs it.
    pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
        let agents = self.agents.lock().await;
        let agent_infos: Vec<AgentInfo> = agents.iter().map(|state| state.info.clone()).collect();
        tracing::info!("agents info: {:?}", agent_infos);
        agent_infos
    }

    /// Returns `(index, failure_count, max_failures)` for every registered agent.
    pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
        let agents = self.agents.lock().await;
        agents
            .iter()
            .enumerate()
            .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
            .collect()
    }

    /// Clears every agent's failure counter.
    ///
    /// Any agent with a non-zero limit becomes eligible for selection again.
    pub async fn reset_failures(&self) {
        let mut agents = self.agents.lock().await;
        for state in agents.iter_mut() {
            state.info.failure_count = 0;
        }
    }

    /// Returns a snapshot of the first agent matching both `provider_name` and `model_name`.
    pub async fn get_agent_by_name(
        &self,
        provider_name: &str,
        model_name: &str,
    ) -> Option<AgentState> {
        let agents = self.agents.lock().await;
        agents
            .iter()
            .find(|state| state.info.provider == provider_name && state.info.model == model_name)
            .cloned()
    }

    /// Returns a snapshot of the first agent whose id equals `id`.
    pub async fn get_agent_by_id(&self, id: i32) -> Option<AgentState> {
        let agents = self.agents.lock().await;
        agents.iter().find(|state| state.info.id == id).cloned()
    }

    /// Prompts a random valid agent, retrying with exponential backoff on any `PromptError`.
    ///
    /// The no-valid-agent error is retried too. `retry_num` overrides the backoff's maximum
    /// number of retries; `None` keeps the `ExponentialBuilder` default.
    ///
    /// # Errors
    ///
    /// Returns the last `PromptError` (converted into `RandAgentError`) once retries are
    /// exhausted.
    pub async fn try_invoke_with_retry(
        &self,
        info: Message,
        retry_num: Option<usize>,
    ) -> Result<String, RandAgentError> {
        let mut config = ExponentialBuilder::default();
        if let Some(retry_num) = retry_num {
            config = config.with_max_times(retry_num);
        }
        let info = Arc::new(info);
        let content = (|| {
            let agent = self.clone();
            let prompt = info.clone();
            async move { agent.prompt((*prompt).clone()).await }
        })
        .retry(config)
        .sleep(tokio::time::sleep)
        .notify(|err: &PromptError, dur: Duration| {
            // Route through `tracing` like the rest of this module instead of writing to stdout.
            tracing::warn!("retrying {err:?} after {dur:?}");
        })
        .await?;
        Ok(content)
    }

    /// Like [`Prompt::prompt`], but also returns the serving agent's metadata.
    ///
    /// The returned `AgentInfo` is captured before the call, so its `failure_count` does not
    /// include this call's outcome. The pool lock is not held while the provider is awaited.
    ///
    /// # Errors
    ///
    /// Returns a `CompletionError::ProviderError` (converted into `PromptError`) when no
    /// valid agent is available. Otherwise it returns whatever error the selected agent
    /// produced.
    pub async fn prompt_with_info(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> Result<(String, AgentInfo), PromptError> {
        // Pick and snapshot under one lock, then release it before the provider call.
        let (agent, agent_info) = {
            let agents = self.agents.lock().await;
            let state = Self::pick_random_valid_index(&agents)
                .map(|index| &agents[index])
                .ok_or_else(|| CompletionError::ProviderError("没有有效agent".to_string()))?;
            (Arc::clone(&state.agent), state.info.clone())
        };
        tracing::info!(
            "prompt_with_info Using provider: {}, model: {},id: {}",
            agent_info.provider,
            agent_info.model,
            agent_info.id
        );
        let result = agent.prompt(prompt).await;
        self.record_outcome(&agent, result.is_ok()).await;
        result.map(|content| (content, agent_info))
    }

    /// Prompts with [`RandAgent::prompt_with_info`], retrying with exponential backoff.
    ///
    /// The retry behavior is the same as [`RandAgent::try_invoke_with_retry`].
    ///
    /// # Errors
    ///
    /// Returns the last `PromptError` (converted into `RandAgentError`) once retries are
    /// exhausted.
    pub async fn try_invoke_with_info_retry(
        &self,
        info: Message,
        retry_num: Option<usize>,
    ) -> Result<(String, AgentInfo), RandAgentError> {
        let mut config = ExponentialBuilder::default();
        if let Some(retry_num) = retry_num {
            config = config.with_max_times(retry_num);
        }
        let info = Arc::new(info);
        let content = (|| {
            let agent = self.clone();
            let prompt = info.clone();
            async move { agent.prompt_with_info((*prompt).clone()).await }
        })
        .retry(config)
        .sleep(tokio::time::sleep)
        .notify(|err: &PromptError, dur: Duration| {
            tracing::warn!("retrying {err:?} after {dur:?}");
        })
        .await?;
        Ok(content)
    }

    /// Returns the index of a uniformly chosen valid agent in `agents`, or `None` if no
    /// agent is valid.
    fn pick_random_valid_index(agents: &[AgentState]) -> Option<usize> {
        let valid_indices: Vec<usize> = agents
            .iter()
            .enumerate()
            .filter(|(_, state)| state.is_valid())
            .map(|(index, _)| index)
            .collect();
        if valid_indices.is_empty() {
            return None;
        }
        let choice = rand::rng().random_range(0..valid_indices.len());
        Some(valid_indices[choice])
    }

    /// Records a call's outcome on the agent behind `agent` and fires `on_agent_invalid`
    /// if this failure is the one that made the agent invalid.
    ///
    /// The state is matched by `Arc` identity. The callback runs after the lock has been
    /// released, and only on the valid -> invalid transition, so concurrent failures notify
    /// once.
    async fn record_outcome(&self, agent: &Arc<AgentVariant>, succeeded: bool) {
        let newly_invalid_id = {
            let mut agents = self.agents.lock().await;
            let Some(state) = agents
                .iter_mut()
                .find(|state| Arc::ptr_eq(&state.agent, agent))
            else {
                return;
            };
            if succeeded {
                state.record_success();
                None
            } else {
                let was_valid = state.is_valid();
                state.record_failure();
                (was_valid && !state.is_valid()).then_some(state.id)
            }
        };
        if let Some(id) = newly_invalid_id
            && let Some(cb) = &self.on_agent_invalid
        {
            cb(id);
        }
    }
}
/// Fluent builder for [`RandAgent`].
pub struct RandAgentBuilder {
    /// Pending `(agent, id, provider, model)` entries, in insertion order.
    pub(crate) agents: Vec<(AgentVariant, i32, String, String)>,
    /// Failure limit applied to every agent added through this builder.
    max_failures: u32,
    /// Callback handed to the built pool, if any.
    on_agent_invalid: OnAgentInvalidCallback,
}
impl RandAgentBuilder {
    /// Starts an empty builder with a failure limit of 3 and no invalidation callback.
    #[must_use]
    pub fn new() -> Self {
        Self {
            agents: Vec::new(),
            max_failures: 3,
            on_agent_invalid: None,
        }
    }

    /// Sets how many consecutive failures each agent may accumulate before it stops being
    /// selected.
    ///
    /// A value of 0 makes every agent invalid from the start.
    #[must_use]
    pub fn max_failures(mut self, max_failures: u32) -> Self {
        self.max_failures = max_failures;
        self
    }

    /// Sets the callback invoked with an agent's id when that agent becomes invalid.
    #[must_use]
    pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
    where
        F: Fn(i32) + Send + Sync + 'static,
    {
        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
        self
    }

    /// Queues an agent identified by `id`, `provider_name` and `model_name`.
    #[must_use]
    pub fn add_agent(
        mut self,
        agent: AgentVariant,
        id: i32,
        provider_name: String,
        model_name: String,
    ) -> Self {
        self.agents.push((agent, id, provider_name, model_name));
        self
    }

    /// Consumes the builder and creates the pool with the configured limit and callback.
    #[must_use]
    pub fn build(self) -> RandAgent {
        RandAgent::with_max_failures_and_callback(
            self.agents,
            self.max_failures,
            self.on_agent_invalid,
        )
    }
}
impl Default for RandAgentBuilder {
fn default() -> Self {
Self::new()
}
}