use std::collections::HashSet;
#[cfg(all(feature = "col", feature = "cli"))]
use crate::agents::generic::GenericAgent;
#[cfg(all(feature = "col", feature = "cli"))]
use crate::cli::models::{default_model, provider_models};
#[cfg(all(feature = "col", feature = "cli"))]
use crate::common::utils::{ClientType, PROVIDER_ENV_MAP, PROVIDER_NAMES, Task};
#[cfg(all(feature = "col", feature = "cli"))]
use crate::tui::state::TuiEvent;
#[cfg(all(feature = "col", feature = "cli"))]
use std::env;
#[cfg(all(feature = "col", feature = "cli"))]
use std::sync::Arc;
#[cfg(all(feature = "col", feature = "cli"))]
use tokio::sync::Mutex;
#[cfg(all(feature = "col", feature = "cli"))]
use tokio::sync::mpsc::{Receiver, UnboundedSender};
/// How a collaboration run chooses the provider it starts distributing tasks from.
///
/// Consumed by `CollabPool::pick_start`.
#[cfg(feature = "col")]
#[derive(Clone, Debug)]
pub enum CollabSelection {
    /// Pick a pseudo-random start index.
    Random,
    /// Start at the provider whose name equals this string exactly
    /// (index 0 when no provider matches).
    Explicit(String),
}
/// Per-provider bookkeeping: the models a provider offers and which one is in use.
#[cfg(all(feature = "col", feature = "cli"))]
#[derive(Clone, Debug)]
pub struct ProviderSlot {
    /// Provider identifier.
    pub name: String,
    /// Candidate model ids, tried in order.
    pub models: Vec<String>,
    /// Index into `models` of the model currently in use.
    pub model_cursor: usize,
    /// Failures recorded against the current model; reset when switching models.
    pub failure_count: u8,
    /// Set once there is no further model to switch to.
    pub exhausted: bool,
}
#[cfg(all(feature = "col", feature = "cli"))]
impl ProviderSlot {
    /// Returns the model id the cursor points at, or `""` when the cursor is
    /// outside `models` (for example when the list is empty).
    pub fn current_model(&self) -> &str {
        match self.models.get(self.model_cursor) {
            Some(model) => model.as_str(),
            None => "",
        }
    }

    /// Moves the cursor to the next model in `models`.
    ///
    /// Returns `true` and resets `failure_count` when another model exists.
    /// Otherwise the slot is flagged `exhausted`, the cursor is left where it
    /// is, and `false` is returned.
    pub fn try_next_model(&mut self) -> bool {
        let next = self.model_cursor + 1;
        if next >= self.models.len() {
            self.exhausted = true;
            return false;
        }
        self.model_cursor = next;
        self.failure_count = 0;
        true
    }
}
/// Rotating pool of providers used to spread collaborative tasks.
///
/// With the `cli` feature each entry carries a fully configured agent;
/// without it the pool only tracks provider names.
#[cfg(feature = "col")]
#[derive(Clone, Debug)]
pub struct CollabPool {
    /// Provider bookkeeping paired with the agent that talks to that provider.
    #[cfg(feature = "cli")]
    pub slots: Vec<(ProviderSlot, GenericAgent)>,
    /// Provider names, in pool order.
    #[cfg(not(feature = "cli"))]
    pub provider_names: Vec<String>,
    /// Monotonic round-robin position; reduced modulo the pool size when used.
    pub cursor: usize,
    /// Indices of providers taken out of rotation.
    pub exhausted: HashSet<usize>,
    /// Optional channel for progress/log events shown in the TUI.
    #[cfg(feature = "cli")]
    pub event_tx: Option<UnboundedSender<TuiEvent>>,
}
#[cfg(feature = "col")]
impl CollabPool {
    /// Builds a pool from bare provider names (non-`cli` builds only).
    ///
    /// The round-robin cursor starts at 0 and no provider is marked exhausted.
    #[cfg(not(feature = "cli"))]
    pub fn from_providers(provider_names: Vec<String>) -> Self {
        Self {
            provider_names,
            cursor: 0,
            exhausted: HashSet::new(),
        }
    }

    /// Number of providers in the pool, exhausted ones included.
    pub fn len(&self) -> usize {
        // `return` + statement-level `cfg` keeps exactly one branch per build.
        #[cfg(feature = "cli")]
        return self.slots.len();
        #[cfg(not(feature = "cli"))]
        return self.provider_names.len();
    }

    /// Returns `true` when the pool holds no providers.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Name of the provider at `idx`, or `""` when `idx` is out of range.
    pub fn provider_name(&self, idx: usize) -> &str {
        #[cfg(feature = "cli")]
        return self
            .slots
            .get(idx)
            .map(|(s, _)| s.name.as_str())
            .unwrap_or("");
        #[cfg(not(feature = "cli"))]
        return self
            .provider_names
            .get(idx)
            .map(|s| s.as_str())
            .unwrap_or("");
    }

    /// Chooses the index at which task distribution should begin.
    ///
    /// - `Random`: a pseudo-random index in `0..len()`.
    /// - `Explicit(name)`: the first provider whose name equals `name` exactly,
    ///   or 0 when none matches.
    ///
    /// An empty pool always yields 0, so callers must not assume the returned
    /// index refers to an existing provider.
    pub fn pick_start(&self, selection: &CollabSelection) -> usize {
        match selection {
            CollabSelection::Random => {
                let pool_size = self.len().max(1);
                // Reducing raw `subsec_nanos` modulo the pool size is heavily biased:
                // on platforms with a coarse wall clock (e.g. 100 ns ticks on Windows)
                // the value is always a multiple of 100 or 1000, so pools of size
                // 2, 4, 5, ... would always start at 0. Hashing the timestamp through
                // a randomly-keyed `RandomState` spreads it over all 64 bits first.
                let nanos = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_nanos();
                let seed = std::hash::BuildHasher::hash_one(
                    &std::collections::hash_map::RandomState::new(),
                    nanos,
                );
                // The remainder is < pool_size, so it always fits in usize.
                (seed % pool_size as u64) as usize
            }
            CollabSelection::Explicit(name) => (0..self.len())
                .find(|&idx| self.provider_name(idx) == name.as_str())
                .unwrap_or(0),
        }
    }
}
#[cfg(all(feature = "col", feature = "cli"))]
impl CollabPool {
    /// Builds a pool with one slot per provider whose credential env var is set.
    ///
    /// A provider is included if it has an entry in `PROVIDER_ENV_MAP` and the
    /// mapped variable is present. For each one, this function:
    /// 1. lists its models via `provider_models`, falling back to a single
    ///    `default_model` entry when that list is empty;
    /// 2. points `AI_PROVIDER` and `<PROVIDER>_MODEL` at the provider and its
    ///    first model, then builds the client with `ClientType::from_env`;
    /// 3. configures a `GenericAgent` with the shared persona, behavior,
    ///    workspace and flags.
    ///
    /// When `event_tx` is given, one `TuiEvent::CollabPool` with every
    /// `(provider, model)` pair is sent after the pool is built (send errors
    /// are ignored).
    ///
    /// NOTE(review): the environment writes are never undone. After this call
    /// `AI_PROVIDER` names the last configured provider, and any user-set
    /// `<PROVIDER>_MODEL` has been replaced by that provider's first model.
    /// Confirm nothing downstream depends on the original values.
    pub fn from_env(
        persona: &str,
        behavior: &str,
        workspace: &str,
        yolo: bool,
        verbose: bool,
        event_tx: Option<UnboundedSender<TuiEvent>>,
        input_rx: Option<Arc<Mutex<Receiver<String>>>>,
    ) -> Self {
        let mut slots = Vec::new();
        for provider in PROVIDER_NAMES {
            let Some(env_key) = PROVIDER_ENV_MAP.get(provider).copied() else {
                continue;
            };
            if env::var(env_key).is_err() {
                // No credentials configured for this provider.
                continue;
            }

            let listed: Vec<String> = provider_models(provider)
                .into_iter()
                .map(|m| m.id)
                .collect();
            // Guarantee at least one model so the slot always has something to run.
            let models = if listed.is_empty() {
                vec![default_model(provider)]
            } else {
                listed
            };
            let primary = models.first().cloned().unwrap_or_default();

            // SAFETY: `set_var` is unsafe because other threads touching the process
            // environment concurrently is undefined behavior. This assumes no other
            // thread reads or writes the environment while the pool is being built.
            // NOTE(review): confirm this holds at every call site.
            unsafe {
                env::set_var("AI_PROVIDER", provider);
                env::set_var(format!("{}_MODEL", provider.to_uppercase()), &primary);
            }
            // Must run after the env writes above; it reads the provider/model from them.
            let client = ClientType::from_env();

            let mut agent = GenericAgent::default();
            agent.agent.persona = persona.to_string().into();
            agent.agent.behavior = behavior.to_string().into();
            agent.yolo = yolo;
            agent.workspace = workspace.to_string();
            agent.model = primary;
            agent.provider = provider.to_string();
            agent.verbose = verbose;
            agent.event_tx = event_tx.clone();
            agent.input_rx = input_rx.clone();
            agent.internet_access = true;
            agent.client = client;

            let slot = ProviderSlot {
                name: provider.to_string(),
                models,
                model_cursor: 0,
                failure_count: 0,
                exhausted: false,
            };
            slots.push((slot, agent));
        }

        if let Some(tx) = &event_tx {
            let provider_list: Vec<(String, String)> = slots
                .iter()
                .map(|(s, _)| (s.name.clone(), s.current_model().to_string()))
                .collect();
            let _ = tx.send(TuiEvent::CollabPool(provider_list));
        }

        Self {
            slots,
            cursor: 0,
            exhausted: HashSet::new(),
            event_tx,
        }
    }

    /// Returns the next non-exhausted slot index in round-robin order.
    ///
    /// Returns `None` when every slot is exhausted or the pool is empty.
    /// `cursor` advances by one for every slot probed, skipped ones included.
    pub fn next_available(&mut self) -> Option<usize> {
        let len = self.slots.len();
        // Also covers the empty pool (0 >= 0), so `% len` below never sees zero.
        if self.exhausted.len() >= len {
            return None;
        }
        for _ in 0..len {
            let idx = self.cursor % len;
            self.cursor += 1;
            if !self.exhausted.contains(&idx) {
                return Some(idx);
            }
        }
        None
    }

    /// Records a failure for slot `idx` and tries to fall back to its next model.
    ///
    /// Returns `true` if the slot switched to another model; the agent's `model`
    /// is updated and a log event is emitted. Returns `false` if the slot has no
    /// models left or `idx` is out of range. An exhausted slot is added to
    /// `exhausted`, and the removal is logged only the first time.
    pub fn mark_failure(&mut self, idx: usize) -> bool {
        let Some((slot, agent)) = self.slots.get_mut(idx) else {
            return false;
        };
        // Exhausted slots never reset the counter, so a plain `+= 1` on this u8
        // would eventually overflow (a panic in debug builds).
        slot.failure_count = slot.failure_count.saturating_add(1);

        if slot.try_next_model() {
            let new_model = slot.current_model().to_string();
            agent.model = new_model.clone();
            if let Some(tx) = &self.event_tx {
                let _ = tx.send(TuiEvent::Log(format!(
                    "\u{1f504} [Collab] {} \u{2192} switching to model {}",
                    slot.name, new_model
                )));
            }
            return true;
        }

        // `insert` returns false if the slot was already removed; don't repeat the log.
        if self.exhausted.insert(idx) {
            if let Some(tx) = &self.event_tx {
                let _ = tx.send(TuiEvent::Log(format!(
                    "\u{26a0} [Collab] {} exhausted all models, removing from pool.",
                    slot.name
                )));
            }
        }
        false
    }

    /// Assigns each task a slot index, round-robin from `start_idx`.
    ///
    /// Exhausted slots are skipped. When nothing is exhausted, task `i` goes to
    /// slot `(start_idx + i) % len`. If no live slot remains, or the pool is
    /// empty, the plain formula is used over all indices (always 0 for an empty
    /// pool), so no task is dropped. Callers should check `agent_mut` /
    /// `next_available` before running a task in that case.
    pub fn distribute(&mut self, tasks: Vec<Task>, start_idx: usize) -> Vec<(usize, Task)> {
        let len = self.slots.len();
        // Live slots in rotation order, beginning at `start_idx`.
        let live: Vec<usize> = if len == 0 {
            Vec::new()
        } else {
            let start = start_idx % len;
            (0..len)
                .map(|offset| (start + offset) % len)
                .filter(|idx| !self.exhausted.contains(idx))
                .collect()
        };
        tasks
            .into_iter()
            .enumerate()
            .map(|(i, task)| {
                let slot_idx = if live.is_empty() {
                    start_idx.wrapping_add(i) % len.max(1)
                } else {
                    live[i % live.len()]
                };
                (slot_idx, task)
            })
            .collect()
    }

    /// Mutable access to the agent of slot `idx`, or `None` if out of range.
    pub fn agent_mut(&mut self, idx: usize) -> Option<&mut GenericAgent> {
        self.slots.get_mut(idx).map(|(_, agent)| agent)
    }
}